diff --git a/.gitignore b/.gitignore index 02c59aaca..ba64bf74c 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,7 @@ workspace.xml .vscode/ .claude/ .windsurf/ +.codebuddy/ # VitePress / frontend build artifacts www/.vitepress/cache/ www/.vitepress/dist/ diff --git a/README.md b/README.md index 9d23fb9f9..ac6e2626b 100644 --- a/README.md +++ b/README.md @@ -131,10 +131,6 @@ neocode -w /path/to/your/project neocode web ``` -标签发布版已经将 Web UI 的 `web/dist` 内嵌进 `neocode` 二进制,执行 `neocode web` 时不再要求用户机器安装 Node.js 或 npm。如果你在源码仓库里运行 `go run ./cmd/neocode web`,当本地缺少 `web/dist` 时仍会自动尝试构建前端。 - -Electron 桌面端发布图标使用已提交的 `web/build/icon.png`、`web/build/icon.ico` 和 `web/build/icon.icns`。只有替换 `web/build/icon.png` 源图时,才需要在 `web/` 目录手动运行 `npm run generate:icons` 重新生成 Windows 与 macOS 图标;该命令在 Windows 使用 PowerShell/.NET 图像能力,在 macOS 使用 `sips`,在 Linux 需要 ImageMagick 的 `magick` 命令。 - ### 4. Web / 飞书快速入口 ```bash diff --git a/docs/architecture/architecture-v1.md b/docs/architecture/architecture-v1.md index d580bb59d..857f9efe9 100644 --- a/docs/architecture/architecture-v1.md +++ b/docs/architecture/architecture-v1.md @@ -524,7 +524,7 @@ graph LR | **模型适配器** | 归一化不同厂商的 Chat API 为统一的 `Generate()` + `EstimateInputTokens()` 接口;将厂商特定的流式响应格式转换为标准 `StreamEvent` | 厂商差异不泄漏到 Runtime;每个 Adapter 独立测试 | Provider | | **工具执行器** | 暴露工具的 Schema 供模型选择;校验参数并执行工具调用;在每次执行前经过安全守卫的权限裁决 | 所有模型可调用的能力收敛于此角色;不在 Runtime 或客户端中绕过 | Tools (Manager) | | **安全守卫** | 基于策略规则(Priority 排序)裁决每个操作的 allow/deny/ask 决策;校验工作区边界(路径穿越检测、Symlink 解析);管理会话级权限记忆 | 位于工具执行的关键路径上,不可跳过 | Security Engine | -| **上下文构建器** | 按会话状态 + 预算阈值动态组装 System Prompt 和消息列表;执行上下文压缩(MicroCompact / Full Compact / Trim) | 压缩时不丢失 System Prompt 和 Pin 标记的关键消息;组装顺序稳定 | Context | +| **上下文构建器** | 按会话状态 + 预算阈值动态组装 System Prompt 和消息列表;执行上下文压缩(Full Compact / Trim) | 压缩时不丢失 System Prompt;组装顺序稳定 | Context | | **状态管理者** | 持久化会话消息历史(SQLite);管理 Checkpoint 快照的创建/恢复/修剪;执行过期会话的自动清理 | 同会话并发写串行化(sessionLock);消息追加原子化 | Session | | **技能注入器** | 从文件系统扫描 SKILL.md;管理会话级 Skill 
激活状态;按激活列表将 Skill Prompt 注入 System Prompt 的技能段落 | project 层覆盖 global 层(同名去重);单文件大小限制 1MB | Skills | | **远程执行代理** | 在远程/本机独立进程中接收 Gateway 的工具执行请求;校验 Capability Token;在本地完成工具执行并返回结果 | 主动连接 Gateway(反向连接);不开放入站端口;受 WorkdirAllowlist 限制 | Runner | @@ -691,11 +691,7 @@ NeoCode 在多处预留了扩展点。本节集中描述:**哪里可以扩展 | `filesystem_edit` | 文件系统 | 基于字符串精确替换的原地编辑 | | `filesystem_glob` | 文件系统 | 文件名模式匹配 | | `filesystem_grep` | 文件系统 | 文件内容正则搜索 | -| `filesystem_copy_file` | 文件系统 | 复制文件 | -| `filesystem_move_file` | 文件系统 | 移动/重命名文件 | -| `filesystem_create_dir` | 文件系统 | 创建目录 | | `filesystem_delete_file` | 文件系统 | 删除文件 | -| `filesystem_remove_dir` | 文件系统 | 删除目录 | | `codebase_read` | 代码库 | 读取代码文件(含语义增强) | | `codebase_search_text` | 代码库 | 基于文本搜索代码库 | | `codebase_search_symbol` | 代码库 | 基于 Tree-sitter 的跨语言符号搜索 | @@ -722,7 +718,7 @@ D4(工具执行可控)要求 AI 的写操作可回滚。Checkpoint 在每次 ### 8.6 Context(Prompt 构建与上下文压缩) -**存在理由:** "AI 看到了什么"是一个独立于"AI 怎么推理"的架构关注点。将 Context 从 Runtime 中分离出来,意味着 Compact 策略(MicroCompact / Full Compact / Trim)可以独立演进,不需要修改推理循环。 +**存在理由:** "AI 看到了什么"是一个独立于"AI 怎么推理"的架构关注点。将 Context 从 Runtime 中分离出来,意味着 Compact 策略(Full Compact / Trim)可以独立演进,不需要修改推理循环。 **拥有的决策权:** System Prompt 的组装顺序(`corePrompt → capabilities → rules → taskState → planModeContext → todos → skillPrompt → repositoryContext → systemState` 的固定顺序);Compact 何时触发、采用什么级别(Micro vs Full vs Trim);哪些消息不能被压缩(Pin 标记)。 @@ -831,7 +827,7 @@ sequenceDiagram **触发条件:** 每轮推理前 `prepareTurnBudgetSnapshot` 检测到 Token 消耗接近预算阈值(基于 Provider 的 `EstimateInputTokens` 估算 + 配置的 `compact_trigger_ratio`)。 -**参与组件:** Runtime → Context Builder → MicroCompact → Compact Runner (Provider) → Session Store +**参与组件:** Runtime → Context Builder → Compact Runner (Provider) → Session Store **流程:** @@ -845,8 +841,6 @@ sequenceDiagram CC-->>RT: needsCompact = true RT->>CC: Compact(input) - CC->>CC: MicroCompact: 对可压缩 tool_result 摘要化 - alt MicroCompact 不足 CC->>CC: Full Compact: CompactRunner.Generate() 生成结构化摘要 end CC->>CC: Trim: 裁剪最旧消息(保留 System Prompt + Pin 标记) @@ -859,8 +853,7 
@@ sequenceDiagram | 级别 | 触发条件 | 操作 | 对上下文的影响 | |------|----------|------|---------------| -| **MicroCompact** | 单次 Tool Call 结果过大,导致本轮预算紧张 | 对单个 tool_result 内容摘要化(保留关键输出,丢弃冗长中间日志) | 仅影响当前工具结果,不改变历史 | -| **Full Compact** | MicroCompact 后仍超预算,或累计历史消息过多 | 将历史消息中可压缩的部分通过 LLM 生成结构化摘要,替换原始消息 | 历史消息被摘要替代,System Prompt + 最近 N 轮保留 | +| **Full Compact** | 累计历史消息过多 | 将历史消息中可压缩的部分通过 LLM 生成结构化摘要,替换原始消息 | 历史消息被摘要替代,System Prompt + 最近 N 轮保留 | **关键不变量:** - System Prompt(corePrompt + capabilities + rules + skillPrompt)永不参与压缩 @@ -1069,7 +1062,7 @@ sequenceDiagram opt 若 Token 预算接近阈值 RT->>CTX: Compact(session) - CTX->>CTX: MicroCompact(tool_results) → FullCompact(history) + CTX->>CTX: FullCompact(history) CTX->>SS: ReplaceTranscript() end @@ -1777,7 +1770,6 @@ SessionID(会话级) + RunID(单次运行级) |------|------| | **ReAct Loop** | Reasoning + Acting 循环:模型推理 → 解析工具调用 → 执行工具 → 回灌结果 → 继续推理,直到产出最终文本回复 | | **Compact** | 上下文压缩:当对话历史累积到接近 Token 预算上限时,自动将历史消息摘要化或裁剪,以释放上下文空间 | -| **MicroCompact** | 轻量级压缩:仅对单个 tool_result 内容做摘要化,不改变消息列表结构。是 Compact 的第一阶段 | | **StreamRelay** | 流式中继:Gateway 内部将 Runtime 的异步事件按 SessionID/RunID 广播到所有订阅客户端连接的 pub/sub 机制 | | **Checkpoint** | 代码版本快照:AI 执行写操作前自动创建的文件状态快照,支持恢复和 Diff 查看 | | **Human-in-the-loop** | 人机协作模式:AI 在执行可能危险的操作(如写文件、执行 Bash)前暂停,等待人类审批 | diff --git a/docs/architecture/architecture-v3.md b/docs/architecture/architecture-v3.md index f46941cd8..90ef79078 100644 --- a/docs/architecture/architecture-v3.md +++ b/docs/architecture/architecture-v3.md @@ -64,7 +64,7 @@ graph TD |------|----------|----------| | **LLM 输出不稳定** | 同样的 Prompt,不同模型(甚至同一模型的不同请求)可能产出完全不同的工具调用策略和代码质量 | Provider 归一化 + Compact 保持上下文一致性 | | **工具执行具有副作用** | 模型决定执行 `rm -rf` 或修改关键配置文件,后果不可撤销 | Security Engine 四层防御 + Checkpoint 自动快照 | -| **上下文窗口有限** | 模型的 context window 有硬上限(4K–200K tokens),长对话和大代码库必然超限 | Context 模块两级 Compact 策略(MicroCompact / Full Compact) | +| **上下文窗口有限** | 模型的 context window 有硬上限(4K–200K tokens),长对话和大代码库必然超限 | Context 模块 Compact 策略(Full Compact / Trim) | | **多轮任务需要状态管理** | 
一次任务可能跨越数十轮推理,中间包含工具调用、审批暂停、错误重试,状态必须一致 | Session 持久化 + Runtime 集中管理会话状态 | | **多端接入的一致性** | TUI、Web、Desktop、飞书、CI 脚本需要用统一协议接入,且行为一致 | Gateway 作为唯一 RPC 边界 + JSON-RPC 2.0 标准协议 | @@ -169,7 +169,7 @@ flowchart TD **第一步:构建上下文。** Runtime 把当前会话状态(消息历史、Todo 列表、激活的 Skills、已批准的 Plan)交给 Context 模块。Context 按固定顺序组装 System Prompt——核心行为准则、工具能力列表、项目规则、当前任务状态、Plan 上下文——然后返回给 Runtime。 -**Runtime 为什么不自己拼 Prompt。** Prompt 的组装逻辑是一个独立的关注点。上下文压缩(Compact)的策略——什么时候触发、用 MicroCompact 还是 Full Compact、哪些消息不能裁剪——需要在 Context 模块内独立演进。如果 Runtime 内嵌了 Prompt 拼接,修改压缩策略就需要改推理循环,两者耦合。 +**Runtime 为什么不自己拼 Prompt。** Prompt 的组装逻辑是一个独立的关注点。上下文压缩(Compact)的策略——什么时候触发、哪些消息不能裁剪——需要在 Context 模块内独立演进。如果 Runtime 内嵌了 Prompt 拼接,修改压缩策略就需要改推理循环,两者耦合。 **第二步:调用模型。** Runtime 把组装好的 Prompt 交给 Provider。Provider 是模型厂商的抽象层——它唯一的职责就是把不同厂商的 API 归一化为两个操作:估算 Token 数、发起流式推理。 @@ -345,7 +345,6 @@ flowchart LR **上下文裁剪(Compact)。** 当消息历史的 Token 数接近模型窗口上限时,Context 模块自动触发压缩: -- **MicroCompact**:移除较早的工具调用细节,保留摘要。优先裁剪输出最长的 Tool Result。 - **Full Compact**:调用 LLM 对整段历史生成摘要,替换原始消息列表。旧消息删除和新摘要插入在同一个 SQLite 事务中完成,保证原子性。 数据回流发生在两处:**工具结果回灌**(Tool Result 写入 Session Messages,供下一轮推理使用)和 **Compact 结果回写**(压缩后的摘要替换原始历史)。 diff --git a/docs/context-compact.md b/docs/context-compact.md index c93d7c57b..62257bf52 100644 --- a/docs/context-compact.md +++ b/docs/context-compact.md @@ -19,10 +19,8 @@ context: compact: manual_strategy: keep_recent manual_keep_recent_messages: 10 - micro_compact_retained_tool_spans: 6 read_time_max_message_spans: 24 max_summary_chars: 1200 - micro_compact_disabled: false budget: prompt_budget: 0 reserve_tokens: 13000 @@ -38,12 +36,8 @@ context: 在 `keep_recent` 模式下保留的最近消息数,并按 tool call / tool result 的原子块整体保留。 - `read_time_max_message_spans` 控制 `context.Builder` 读时 trim 可保留的 message span 上限。 -- `micro_compact_retained_tool_spans` - 控制 read-time micro compact 默认保留原始内容的最近可压缩工具块数量。 - `max_summary_chars` 控制 compact summary 的最大字符数。 -- `micro_compact_disabled` - 控制是否关闭默认启用的 read-time micro compact。 ### `context.budget` diff --git 
a/docs/examples/hooks.yaml b/docs/examples/hooks.yaml index 7a89625cf..ad90911d8 100644 --- a/docs/examples/hooks.yaml +++ b/docs/examples/hooks.yaml @@ -17,8 +17,9 @@ hooks: kind: builtin mode: sync handler: warn_on_tool_call + match: + tool_name: ["bash"] params: - tool_names: ["bash"] message: "执行 bash 前请确认命令不会破坏工作区。" - id: require-readme-before-final diff --git a/docs/examples/user-hooks-config.yaml b/docs/examples/user-hooks-config.yaml index 055b35626..0228cc5f9 100644 --- a/docs/examples/user-hooks-config.yaml +++ b/docs/examples/user-hooks-config.yaml @@ -25,8 +25,9 @@ runtime: kind: builtin mode: sync handler: warn_on_tool_call + match: + tool_name: ["bash"] params: - tool_names: ["bash"] message: "执行 bash 前请确认命令不会破坏工作区。" - id: user-http-observe diff --git a/docs/gateway-error-catalog.md b/docs/gateway-error-catalog.md index 7c7e482f4..1c2d33f0f 100644 --- a/docs/gateway-error-catalog.md +++ b/docs/gateway-error-catalog.md @@ -10,6 +10,7 @@ | `missing_required_field` | 200 | -32602 | 缺失必填字段(如 `params.session_id`、`params.request_id`、`payload.run_id`)。 | 直接失败,补齐字段。 | | `unsupported_action` | 200 | -32601 | 方法不存在或当前版本未实现。 | 降级到兼容方法,或提示版本不支持。 | | `internal_error` | 200 | -32603 | 网关内部异常、运行时不可用、不可归类的执行失败。 | 可短暂重试;持续失败需告警。 | +| `max_turn_exceeded` | 200 | -32602 | Runtime 达到 `runtime.max_turns` 后受控停止;异步 `gateway.run` 会通过 `run_error.stop_reason=max_turn_exceeded` 透传。 | 提示用户可继续发送消息、拆分任务或调高 `runtime.max_turns`,不要按网关内部错误告警。 | | `timeout` | 200 | -32603 | Gateway 调用 runtime 超过操作超时窗口。 | 可重试并增加客户端超时预算;必要时调用 `gateway.cancel`。 | | `unauthorized` | 401 | -32602 | 未提供有效 token 或连接未完成认证。 | 刷新凭据并重新认证,不建议盲重试。 | | `access_denied` | 403 | -32602 | 已认证但 ACL/主体权限不允许当前动作或资源访问。 | 直接失败,提示权限不足。 | diff --git a/docs/gateway-rpc-api.md b/docs/gateway-rpc-api.md index 876534147..c3013e963 100644 --- a/docs/gateway-rpc-api.md +++ b/docs/gateway-rpc-api.md @@ -155,7 +155,8 @@ type BindStreamParams struct { ```go type RunInputMedia struct { - URI string `json:"uri"` + URI string 
`json:"uri,omitempty"` + AssetID string `json:"asset_id,omitempty"` MimeType string `json:"mime_type"` FileName string `json:"file_name,omitempty"` } @@ -175,6 +176,12 @@ type RunParams struct { } ``` +- 多模态图片约束: + - `type=image` 时 `media.mime_type` 必填。 + - `media.uri` 与 `media.asset_id` 必须二选一,不能同时为空或同时提供。 + - `media.uri` 仅用于后端可读取的本地路径;Web 浏览器上传图片应先通过 `POST /api/session-assets` 保存,再在 `gateway.run` 中使用 `media.asset_id` 引用。 + - `asset_id` 必须属于当前 `session_id`,不存在或跨 session 引用会在 runtime 输入准备阶段失败。 + - Response Schema: - Success(受理即返回): @@ -223,6 +230,49 @@ type RunParams struct { --- +## HTTP API: session assets + +浏览器图片上传不应把本地伪路径传给 Runtime。Web 客户端需要在发送前先创建或确认 `session_id`,再通过受鉴权保护的 HTTP API 保存图片,最后在 `gateway.run.input_parts[].media.asset_id` 中引用。 + +### POST /api/session-assets + +- Auth Required: Yes(`Authorization: Bearer `) +- Headers: + - `X-NeoCode-Workspace-Hash`: 当前工作区哈希。多工作区 Web 客户端必须发送;单工作区或旧客户端可省略并回落到默认工作区。 +- Content-Type: `multipart/form-data` +- Fields: + - `session_id`: 目标会话 ID,必填。 + - `file`: 图片文件,必填。 +- Server-side validation: + - 仅接受 `image/png`、`image/jpeg`、`image/webp`。 + - MIME 以服务端文件头检测结果为准,不信任浏览器声明。 + - 空文件返回 `400`。 + - 超过 `MaxSessionAssetBytes` 返回 `413`。 + - 非图片或不支持类型返回 `415`。 + - 未认证返回 `401`,Origin/CORS 或 ACL 拒绝返回 `403`。 + - 工作区不存在返回 `404 workspace not found`;目标 session 不在该工作区返回 `404 session not found`。 +- Response: + +```json +{ + "session_id": "sess-1", + "asset_id": "asset-1", + "mime_type": "image/png", + "size": 1024 +} +``` + +### GET /api/session-assets/{session_id}/{asset_id} + +- Auth Required: Yes(`Authorization: Bearer `) +- Headers: + - `X-NeoCode-Workspace-Hash`: 当前工作区哈希。多工作区 Web 客户端必须发送;省略时回落到默认工作区。 +- 返回图片二进制,`Content-Type` 为保存时确认的 MIME。 +- 用于历史消息缩略图按需读取。 +- 工作区不存在返回 `404 workspace not found`;不存在或不可见的 asset 返回 `404 asset not found`。 + +--- + ## Method: gateway.compact - Stability: Stable @@ -421,6 +471,41 @@ type ResolvePermissionParams struct { --- +## Method: gateway.approvePlan + +- Stability: Stable +- Auth Required: Yes +- 
Request Schema: + +```go +type ApprovePlanParams struct { + SessionID string `json:"session_id"` // MUST + PlanID string `json:"plan_id"` // MUST + Revision int `json:"revision"` // MUST > 0 +} +``` + +- Response Schema: + +```json +{ + "type": "ack", + "action": "approve_plan", + "session_id": "session-1", + "payload": { + "plan_id": "plan-1", + "revision": 2, + "status": "approved" + } +} +``` + +- Semantics: + - 仅批准当前会话中匹配 `plan_id + revision` 的 `draft` 计划。 + - 成功后客户端可再调用 `gateway.run({ "mode": "build" })` 执行已批准计划。 + +--- + ## Method: gateway.userQuestionAnswer - Stability: Beta diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 229db91cc..c1b08047b 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -61,10 +61,8 @@ context: compact: manual_strategy: keep_recent manual_keep_recent_messages: 10 - micro_compact_retained_tool_spans: 6 read_time_max_message_spans: 24 max_summary_chars: 1200 - micro_compact_disabled: false budget: prompt_budget: 0 reserve_tokens: 13000 @@ -89,10 +87,8 @@ context: |------|------| | `context.compact.manual_strategy` | `/compact` 手动压缩策略,支持 `keep_recent` / `full_replace` | | `context.compact.manual_keep_recent_messages` | `keep_recent` 下保留的最近消息数 | -| `context.compact.micro_compact_retained_tool_spans` | read-time micro compact 默认保留原始内容的最近工具块数量 | | `context.compact.read_time_max_message_spans` | context 构建时保留的 message span 上限 | | `context.compact.max_summary_chars` | compact summary 最大字符数 | -| `context.compact.micro_compact_disabled` | 是否关闭默认启用的 micro compact | ### `context.budget` diff --git a/docs/reference/gateway-error-catalog.md b/docs/reference/gateway-error-catalog.md index 7967c374d..c9a6712f1 100644 --- a/docs/reference/gateway-error-catalog.md +++ b/docs/reference/gateway-error-catalog.md @@ -2,7 +2,7 @@ 本文档用于第三方客户端实现统一异常处理策略,覆盖 Gateway 稳定错误码集合: 
-`invalid_frame`、`invalid_action`、`invalid_multimodal_payload`、`missing_required_field`、`unsupported_action`、`internal_error`、`timeout`、`unauthorized`、`access_denied`、`resource_not_found`。 +`invalid_frame`、`invalid_action`、`invalid_multimodal_payload`、`missing_required_field`、`unsupported_action`、`internal_error`、`max_turn_exceeded`、`timeout`、`unauthorized`、`access_denied`、`resource_not_found`。 ## 1. 错误码对照表 @@ -10,10 +10,11 @@ | --- | --- | --- | --- | --- | --- | | `invalid_frame` | `200` | `-32700` / `-32600` / `-32602` | 请求帧结构或编码不合法。包括 JSON 解析失败、请求体包含多余 JSON 值、`id/jsonrpc` 非法、`params` 严格解码失败。 | 非法 JSON;`id` 为 `null`;`params` 含未知字段。 | 不要直接重试,先修复请求构造器。 | | `invalid_action` | `200` | `-32602` | 动作参数值非法,但方法本身存在。 | `params.channel` 不在 `all/ipc/ws/sse`;`params.decision` 非 `allow_once/allow_session/reject`。 | 视为调用方输入错误,修正参数后再发。 | -| `invalid_multimodal_payload` | `200` | `-32602` | `gateway.run` 的 `input_parts` 结构或字段不满足契约。 | `image` 分片缺少 `media.uri` 或 `media.mime_type`;`text` 分片文本为空。 | 校验输入分片后重试,不做盲重试。 | +| `invalid_multimodal_payload` | `200` | `-32602` | `gateway.run` 的 `input_parts` 结构或字段不满足契约。 | `image` 分片缺少 `media.mime_type`,或 `media.uri` / `media.asset_id` 未满足二选一;`text` 分片文本为空。 | 校验输入分片后重试,不做盲重试。 | | `missing_required_field` | `200` | `-32600` / `-32602` | 缺失必填字段。请求层字段缺失多映射为 `-32600`,方法参数层字段缺失多映射为 `-32602`。 | 缺失 `id`;缺失 `params`;`cancel` 缺失 `run_id`。 | 调整参数补齐必填项再重试。 | | `unsupported_action` | `200` | `-32601` | 方法未注册或不被网关识别。 | 调用不存在的方法名。 | 客户端按能力探测降级,或升级服务端版本。 | | `internal_error` | `200` | `-32603` | 网关内部异常或未分类下游异常。 | 结果编码失败;runtime port 不可用;未知运行时错误。 | 采用指数退避重试;持续失败时告警。 | +| `max_turn_exceeded` | `200` | `-32602` | Runtime 达到 `runtime.max_turns` 后受控停止。 | 异步 `gateway.run` 通过 `run_error` 返回 `stop_reason=max_turn_exceeded`。 | 提示用户继续发送消息、拆分任务或调高 `runtime.max_turns`;不要按网关内部错误告警。 | | `timeout` | `200` | `-32603` | 网关调用 runtime 超时(`context.DeadlineExceeded`)。 | `run/compact/cancel/loadSession/resolvePermission` 下游调用超时。 | 可重试且建议带幂等键(如固定 `run_id`)。 | | `unauthorized` | 
`401`(仅 /rpc) | `-32602` | 请求未通过认证。 | 未携带 token;token 非法;连接未先 `authenticate`。 | 先刷新凭证并重新认证,认证成功后再发业务请求。 | | `access_denied` | `403`(仅 /rpc) | `-32602` | 已认证但不具备该方法或资源权限。 | ACL 拒绝当前来源调用该方法;runtime 返回 access denied。 | 终止当前请求并提示授权不足,不要盲重试。 | diff --git a/docs/reference/gateway-rpc-api.md b/docs/reference/gateway-rpc-api.md index 4ec01eec0..82dad5784 100644 --- a/docs/reference/gateway-rpc-api.md +++ b/docs/reference/gateway-rpc-api.md @@ -306,6 +306,13 @@ type RunParams struct { Mode string `json:"mode,omitempty"` // Agent 工作模式:build|plan,可选,默认沿用 session 当前 mode } +type RunInputMedia struct { + URI string `json:"uri,omitempty"` + AssetID string `json:"asset_id,omitempty"` + MimeType string `json:"mime_type"` + FileName string `json:"file_name,omitempty"` +} + type RunInputPart struct { Type string `json:"type"` // text|image Text string `json:"text,omitempty"` // text MUST @@ -318,7 +325,7 @@ type RunInputPart struct { 1. `input_text` 与 `input_parts` 至少一项非空。 2. `input_parts` 中: 1. `type=text` 时 `text` `MUST` 非空。 -2. `type=image` 时 `media.uri` 与 `media.mime_type` `MUST` 非空。 +2. `type=image` 时 `media.mime_type` `MUST` 非空,`media.uri` 与 `media.asset_id` `MUST` 二选一且不能同时提供。Web 上传图片应先调用 `POST /api/session-assets`,再在 `gateway.run` 中用 `asset_id` 引用。 3. 未知字段会因严格解码触发 `invalid_frame`。 4. `run_id` 归一化顺序为:显式 `run_id` > `request_id` > 网关生成 `run_`。 5. 
`mode` 可选值为 `"build"` 或 `"plan"`,为空时默认沿用 session 当前 mode(新会话默认为 `"build"`)。切换 mode 后,后端会更新 session 并影响后续运行的工具可用性和 prompt 策略。 @@ -397,6 +404,37 @@ sequenceDiagram G-->>C: ack(cancel) ``` +### HTTP session asset API + +浏览器图片上传使用 HTTP API,不通过 JSON-RPC 传输文件内容。客户端发送图片前需要先拥有有效 `session_id`(新会话可先调用 `gateway.createSession`)。 + +`POST /api/session-assets` + +- Auth Required: `Yes`,使用 `Authorization: Bearer `。 +- Headers: `X-NeoCode-Workspace-Hash` 携带当前工作区哈希;多工作区 Web 客户端必须发送,省略时回落到默认工作区。 +- Content-Type: `multipart/form-data`。 +- 字段:`session_id`(必填)、`file`(必填)。 +- 仅接受 PNG/JPEG/WebP;服务端按文件头检测 MIME,不信任浏览器声明。 +- 空文件返回 `400`,超出 `MaxSessionAssetBytes` 返回 `413`,不支持 MIME 返回 `415`,未认证返回 `401`,Origin/CORS 或 ACL 拒绝返回 `403`。 +- 工作区不存在返回 `404 workspace not found`;目标 session 不在该工作区返回 `404 session not found`。 +- 成功返回: + +```json +{ + "session_id": "session-1", + "asset_id": "asset-1", + "mime_type": "image/png", + "size": 1024 +} +``` + +`GET /api/session-assets/{session_id}/{asset_id}` + +- Auth Required: `Yes`。 +- Headers: `X-NeoCode-Workspace-Hash` 携带当前工作区哈希;多工作区 Web 客户端必须发送。 +- 返回图片二进制,用于历史消息缩略图。 +- 工作区不存在返回 `404 workspace not found`;不存在、跨 session 或不可见的 asset 返回 `404 asset not found`。 + Observation: 1. `gateway_requests_total{method="gateway.run",status="ok|error"}`。 @@ -778,7 +816,44 @@ Observation: --- -## 15. wake.openUrl +## 15. gateway.approvePlan + +Method: `gateway.approvePlan` +Stability: `Stable` +Auth Required: `Yes` + +Request Schema: + +```go +type ApprovePlanParams struct { + SessionID string `json:"session_id"` // MUST + PlanID string `json:"plan_id"` // MUST + Revision int `json:"revision"` // MUST > 0 +} +``` + +Response Schema: + +```json +{ + "type": "ack", + "action": "approve_plan", + "session_id": "session-1", + "payload": { + "plan_id": "plan-1", + "revision": 2, + "status": "approved" + } +} +``` + +Semantics: +1. Only the current session plan matching `plan_id + revision` and `draft` status can be approved. +2. 
After success, clients can call `gateway.run` with `mode: "build"` to execute the approved plan. + +--- + +## 16. wake.openUrl Method: `wake.openUrl` Stability: `Experimental` @@ -828,7 +903,7 @@ Observation: --- -## 16. gateway.event(服务端通知) +## 17. gateway.event(服务端通知) Method: `gateway.event` Stability: `Stable` diff --git a/docs/roadmap.md b/docs/roadmap.md index fe6084393..f27a9d1f1 100644 --- a/docs/roadmap.md +++ b/docs/roadmap.md @@ -11,7 +11,6 @@ | **传输层全 HTTP 化** | 当前 `transport/` 中残留的 Unix socket / Named pipe 逻辑增加了双平台代码路径和维护负担。统一到 HTTP JSON-RPC 后:第三方客户端接入更简单(只需要发 HTTP POST,不需要理解 Unix socket 地址规则)、Windows 和 Linux/macOS 的客户端连接逻辑完全一致 | 高——正在进行 | | **Gateway 大文件拆分** | `bootstrap.go` 超 1600 行,包含帧路由、认证、session CRUD、RPC 处理、流绑定等所有 Gateway 逻辑。5 人团队每人负责不同模块,但 Gateway 的改动集中在同一大文件中 → 持续的合并冲突。按功能域拆分为 `auth_handler.go`、`session_handler.go`、`stream_handler.go` 后,各自改自己的文件 | 高——直接影响并行开发效率 | | **Runner 工具并行执行** | Runner 当前串行处理 Gateway 下发的工具请求。在"手机飞书下指令 → 工位 Runner 执行"场景中,模型经常一次产出多个独立的 tool call(如同时读 3 个文件),串行执行导致不必要的延迟。改为并行执行可显著改善远程场景的响应体验 | 中——核心差异化场景的性能瓶颈 | -| **Compact 配置收敛** | MicroCompact 的 `MicroCompactConfig`、Full Compact 的 `CompactConfig`、预算阈值在 `RuntimeConfig` 中分散定义。调整上下文压缩策略时需要理解三个不同的配置入口,容易产生不一致的配置。收敛为单一 `CompactPolicy` 结构体 | 中——降低调优门槛 | ### 17.2 中期(巩固和放大现有差异化优势) diff --git a/docs/runtime-hooks-design.md b/docs/runtime-hooks-design.md index f86141349..f4444a0f2 100644 --- a/docs/runtime-hooks-design.md +++ b/docs/runtime-hooks-design.md @@ -13,10 +13,11 @@ - P4:生命周期点位扩展(permission/session/compact/subagent)+ 点位能力矩阵 - P5:internal hooks 支持 `async/async_rewake` + run 内存通知队列(ephemeral 注入) - P6-lite:user `http/observe` hooks(仅观测回调) +- P6:user/repo `command` hooks(stdin/stdout JSON 协议) 当前未实现能力: -- command/prompt/agent hooks(P6) +- prompt/agent hooks(P6) ## P2 user hooks 边界 @@ -26,13 +27,18 @@ P2 仅支持: - `kind=builtin` - `mode=sync` - 挂载点:与 `HookPointCapability` 中 `UserAllowed=true` 的点位一致,当前包括: - 
`before_tool_call`、`after_tool_result`、`before_completion_decision`、`after_tool_failure`、 + `before_tool_call`、`after_tool_result`、`before_completion_decision`、`accept_gate`、`after_tool_failure`、 `session_start`、`session_end`、`user_prompt_submit`、`post_compact`、`subagent_stop` - handler:`require_file_exists`、`warn_on_tool_call`、`add_context_note` +- `match`:统一 matcher DSL(字段间 AND、同字段多值 OR),支持: + - `tool_name`:精确匹配(`string` 或 `[]string`) + - `tool_name_regex`:正则匹配(`string` 或 `[]string`,单条最长 256) + - `arguments_contains`:参数预览包含匹配(`[]string`) - `kind=http + mode=observe`:允许发送 HTTP 观测回调(不支持 block) - `http observe` 默认不携带 metadata(`include_metadata=false`);即使显式开启也会剥离 `result_content_preview`、`execution_error` - `http observe` 回调端点仅允许 loopback 地址(`localhost` / `127.0.0.1` / `::1`),避免误配为公网外发 -- external kinds 中 `command/prompt/agent` 在 P6-lite 阶段显式拒绝,不会半生效 +- `kind=command + mode=sync`:允许执行外部命令,通过 stdin/stdout JSON 协议通信(详见下方 P6 章节) +- external kinds 中 `prompt/agent` 仍显式拒绝 当前(P3)明确不支持: @@ -71,6 +77,7 @@ user/repo hook 接收的 `HookContext` 经过白名单裁剪,仅保留最小 - `run_id` / `session_id` - `point` / `tool_call_id` / `tool_name` +- `tool_arguments_preview`(脱敏+截断后的参数预览) - `is_error` / `error_class` - `result_content_preview` / `result_metadata_present` - `execution_error` @@ -91,6 +98,7 @@ runtime 内置 `HookPointCapability` 作为唯一真源,定义每个点位是 - `before_tool_call` - `after_tool_result` - `before_completion_decision` +- `accept_gate` - `before_permission_decision` - `after_tool_failure` - `session_start` @@ -104,8 +112,24 @@ runtime 内置 `HookPointCapability` 作为唯一真源,定义每个点位是 约束规则: - `CanBlock=false` 的点位,hook 返回 `block` 会自动降级为观测结果,不中断主链。 -- `CanUpdateInput` 仅作为能力建模;当前阶段不开放输入改写通道。 +- `CanUpdateInput` 在 `user_prompt_submit` 点位已开放:command hook 可通过 stdout JSON 的 `update_input` 字段改写用户输入。 - `UserAllowed=false` 的点位拒绝 user/repo 挂载(配置 fail-fast)。 +- matcher 字段会按点位能力矩阵做 fail-fast:不支持的维度会在配置加载阶段直接报错。 + +### matcher 点位维度矩阵(#684) + +| point | tool_name | tool_name_regex | arguments_contains | +|---|---|---|---| 
+| `before_tool_call` | ✅ | ✅ | ✅ | +| `after_tool_result` | ✅ | ✅ | ❌ | +| `after_tool_failure` | ✅ | ✅ | ✅ | +| `before_permission_decision` | ✅ | ✅ | ❌ | +| 其他点位 | ❌ | ❌ | ❌ | + +说明: + +- `arguments_contains` 基于 `tool_arguments_preview` 字段匹配,不读取 `tool_arguments` 原文。 +- `warn_on_tool_call` 当前要求显式配置 `match`;旧参数 `params.tool_name/tool_names` 不再承担匹配语义。 ### trust gate @@ -134,6 +158,159 @@ trust store 固定路径: - 绝对路径必须位于 workdir 内 - symlink 路径会进行 realpath 校验,禁止绕过 +## P6 command hooks + +`kind=command` 允许 user/repo scope 通过外部可执行脚本参与 hook 链。 + +### stdin 协议 + +外部命令通过 stdin 接收单行 JSON: + +```json +{ + "payload_version": "1", + "hook_id": "my-hook", + "point": "before_tool_call", + "run_id": "run_abc123", + "session_id": "sess_abc123", + "metadata": { + "tool_name": "bash", + "workdir": "/path/to/workspace" + } +} +``` + +- `payload_version`:协议版本号,当前固定 `"1"`,变更 stdin 结构时递增 +- `hook_id`:hook 配置中的 `id` +- `point`:触发点位名称 +- `metadata`:经白名单裁剪后的上下文字段(与 builtin/http hook 相同的 allowlist) + +### stdout 协议 + +外部命令通过 stdout 返回单行 JSON: + +```json +{ + "status": "pass", + "message": "optional message", + "update_input": {"text": "rewritten prompt"}, + "annotations": ["note1", "note2"] +} +``` + +- `status`:必填,`pass` / `block` / `failed` +- `message`:可选,进入 hook event 和 annotation buffer +- `update_input`:仅 `CanUpdateInput=true` 的点位(当前仅 `user_prompt_submit`)允许;格式 `{"text": "..."}` 替换用户输入文本 +- `annotations`:字符串数组,进入 runtime annotation buffer + +### stdout 退化模式 + +如果 stdout 不是合法 JSON,handler 退化为 exit code 模式: + +- exit 0 → `pass` +- exit 1 或 2 → `block` +- 其他 → `failed` + +原始 stdout 文本作为 `message`。此模式兼容简单脚本(如 `echo "ok"; exit 0`)。 + +### 执行模式 + +#### argv 模式(默认) + +`params.command` 为字符串数组,直接 exec 不经 shell: + +```yaml +kind: command +params: + command: + - python3 + - /path/to/hook.py +``` + +#### shell 模式 + +`params.command` 为字符串且 `params.shell: true`,通过 `sh -c`(Unix)/ `powershell -Command`(Windows)执行: + +```yaml +kind: command +params: + command: "python3 /path/to/hook.py" + shell: true 
+``` + +单字符串 `params.command` 不设置 `params.shell: true` 会触发配置校验错误。 + +### 环境变量 + +命令进程仅注入以下环境变量,不继承宿主环境: + +| 变量 | 值 | +|------|------| +| `NEOCODE_HOOK_HOOK_ID` | hook 的 `id` | +| `NEOCODE_HOOK_POINT` | 触发点位(如 `before_tool_call`) | +| `NEOCODE_HOOK_PAYLOAD_VERSION` | `"1"` | + +Windows 额外注入 `SystemRoot`、`SystemDrive`、`USERPROFILE`(从宿主环境读取),以确保 TLS 证书加载和运行时基础功能正常工作。 + +### 执行约束 + +- workdir = 当前 run 的 workspace(`cmd.Dir = workdir`) +- 超时 = hook 配置的 `timeout_sec`(默认 2s) +- 并发限制 = executor 的 `max_in_flight`(默认 128) +- repo scope command hook 受 trust gate 保护 +- stdout 大小限制 = 1 MiB;超出视为 `failed` + +### stderr 处理 + +外部命令的 stderr 与 stdout 分离捕获。stderr 不会混入 `message` 字段,仅在命令执行失败(非零 exit code)且 stdout 无可用 message 时,stderr 内容才作为 fallback 追加到结果中。此设计确保 hook 协议输出(stdout JSON)不受调试输出(stderr)干扰。 + +### stdin 字段说明 + +- `run_id` / `session_id` 同时出现在 payload 顶层和 `metadata` 中。**顶层字段为权威来源**,`metadata` 中的同名字段为冗余副本(与 builtin/http hook 的 metadata allowlist 一致)。外部脚本应优先读取顶层字段。 +- `payload_version` 当前固定为 `"1"`,变更 stdin 结构时递增。 + +### update_input 与 block 交互 + +当 hook 返回 `status: "block"` 时,`update_input` 不会被应用。阻断优先于输入改写——hook 链在检测到 block 后立即终止,不进入 `applyCommandHookUpdateInput` 逻辑。 + +### 安全:exit code 优先于 JSON status + +当命令以非零 exit code 退出时,stdout 中 JSON 声称的 `status` 字段被忽略。exit code 的映射优先: + +- exit 1/2 → `block` +- 其他非零 → `failed` + +此规则防止恶意脚本通过 `{"status":"pass"}` 掩盖实际失败。JSON 中的 `message` 和 `annotations` 仍会被提取(如果 stdout 是合法 JSON)。 + +### 示例 + +#### Python + +```python +#!/usr/bin/env python3 +import json, sys + +payload = json.loads(sys.stdin.readline()) +if payload["metadata"].get("tool_name") == "bash": + json.dump({"status": "block", "message": "bash not allowed"}, sys.stdout) +else: + json.dump({"status": "pass"}, sys.stdout) +print() +``` + +#### Bash + +```bash +#!/bin/bash +read -r line +tool=$(echo "$line" | jq -r '.metadata.tool_name // empty') +if [ "$tool" = "rm" ]; then + echo '{"status":"block","message":"rm is blocked"}' +else + echo '{"status":"pass"}' +fi +``` + ## 可观测性 
runtime 会透传 hooks 生命周期事件: @@ -172,3 +349,75 @@ user/repo hook 的 `message` 会进入 runtime 的 annotation buffer(运行态 - `fail_closed` -> `fail_closed` 其中 `warn_only/fail_open` 不阻断主链,仅记录失败;`fail_closed` 触发阻断。 + +## Runtime 事件契约 + +runtime 事件在三端之间传递,任一端遗漏不会触发编译错误,仅在运行时表现为"事件丢失"或"未知事件被透传"。契约检查器通过 CI 测试强制三端一致性。 + +### 事件流转路径 + +```text +runtime (events.go) → gateway protocol encode → gateway_stream_client decode → TUI update handler consume +``` + +### 新增 runtime event 三步清单 + +当新增一个 runtime event 时,必须完成以下三步: + +**Step 1:定义事件常量与 payload** + +在 `internal/runtime/events.go`(或 `events_subagent.go`)中添加 `Event*` 常量和对应的 payload 结构体。 + +```go +// events.go +const EventMyNewEvent EventType = "my_new_event" + +type MyNewEventPayload struct { + Field string `json:"field"` +} +``` + +**Step 2:添加 gateway decode 分支** + +在 `internal/tui/services/gateway_stream_client.go` 的 `restoreRuntimePayload` 函数中添加对应的 case 分支: + +```go +case EventMyNewEvent: + return decodeRuntimePayload[MyNewEventPayload](payload) +``` + +同时在 `internal/tui/services/runtime_contract.go` 中: +- 添加 `EventMyNewEvent` 常量定义 +- 在 `contractRegistry` 中注册,设置 `RequireConsumer` 为 `true`(需要 TUI 消费)或 `false`(透传安全) + +```go +// runtime_contract.go +const EventMyNewEvent EventType = "my_new_event" + +// contractRegistry 中添加: +EventMyNewEvent: {RequireConsumer: true}, +``` + +**Step 3:添加 TUI 消费者** + +在 `internal/tui/core/app/update.go` 的 `runtimeEventHandlerRegistry` 中添加对应 handler: + +```go +// update.go - runtimeEventHandlerRegistry 中添加: +tuiservices.EventMyNewEvent: runtimeEventMyNewEventHandler, +``` + +### CI 契约检查 + +以下测试用例在 CI 中强制执行事件契约一致性: + +- `TestRuntimeEventContractConsistency`:扫描 runtime 事件常量,未注册且不在 `legacyPassthroughEvents` 中的事件会导致 CI 失败 +- `TestGatewayDecodeBranchConsistency`:验证 gateway decode 分支中的事件都在 contractRegistry 中注册 +- `TestRequireConsumerMustHaveDecodeBranch`:验证 `RequireConsumer=true` 的事件必须有 gateway decode 分支 +- `TestRequireConsumerMustHaveTUIConsumer`:验证 `RequireConsumer=true` 的事件必须在 `runtimeEventHandlerRegistry` 中有 
handler + +若 CI 失败,检查以上三步是否遗漏。 + +### 遗留透传事件 + +`legacyPassthroughEvents` 是已知的遗留透传事件允许列表,这些事件在 contractRegistry 建立之前已存在,允许不注册。新增的 runtime Event* 常量必须显式注册到 contractRegistry,否则 CI 失败。 diff --git a/docs/runtime-provider-event-flow.md b/docs/runtime-provider-event-flow.md index 5b9a12e0b..9d98fcaa2 100644 --- a/docs/runtime-provider-event-flow.md +++ b/docs/runtime-provider-event-flow.md @@ -88,7 +88,7 @@ runtime 不再消费旧的 builder 压缩建议,而是使用冻结快照上的 - 组装 `system prompt` - 读取 `AGENTS.md` - 注入 `Task State` / `Todo State` / `Skills` / `Memo` -- 执行 read-time trim 和 micro compact +- 执行 read-time trim - 输出最终 `SystemPrompt` 与消息列表 `context.Builder` 不再负责: diff --git a/docs/tech-debt.md b/docs/tech-debt.md index 4ace9ad13..31262d4dc 100644 --- a/docs/tech-debt.md +++ b/docs/tech-debt.md @@ -26,7 +26,6 @@ |--------|------|------|-------------| | **底层传输层 IPC 残留** | `internal/gateway/transport/` — Unix domain socket / Named pipe | 客户端连接路径复杂(需判断平台选 socket 类型),迁移到全 HTTP 后可消除 | 短期——已在迁移计划中 | | **`runtime/run.go` 单文件过长** | ReAct 主循环逻辑集中在 `run.go` (~400 行) 和 `runtime.go` (~540 行) | 新成员理解核心循环需要较长时间;修改风险集中在少数大文件中 | 中期——可按阶段拆分(pre-processing / loop body / termination) | -| **Compact 策略配置分散** | MicroCompact 配置在 `MicroCompactConfig`,Full Compact 在 `CompactConfig`,部分阈值在 `RuntimeConfig` | 调整上下文管理策略需要理解三处配置 | 中期——收敛为统一的 `CompactPolicy` 结构体 | | **Gateway Bootstrap 单文件** | `bootstrap.go` 超过 1600 行,包含帧路由、认证、session CRUD、RPC 处理 | 单体文件难以定位和维护 | 中期——拆分为 `session_handler.go`、`rpc_handler.go`、`auth_handler.go` | | **Acceptance 测试耗时长** | `runtime/acceptance/` 的端到端测试依赖真实模型 API | CI 成本高、不稳定(网络波动导致 flaky) | 长期——增加录制/回放(VCR)模式,CI 中默认使用录制的 fixture | diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 41bea5652..45fcd3c17 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -179,14 +179,7 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime log.Printf("session cleanup warning: %v", err) } - // 注册内置工具的内容摘要器,使 micro-compact 在清理旧工具结果时保留关键上下文。 - 
tools.RegisterBuiltinSummarizers(toolRegistry) - - microCompactCfg := agentcontext.MicroCompactConfig{ - Policies: toolRegistry, - Summarizers: toolRegistry, - } - var contextBuilder agentcontext.Builder = agentcontext.NewConfiguredBuilder(microCompactCfg) + var contextBuilder agentcontext.Builder = agentcontext.NewConfiguredBuilder() var memoSvc *memo.Service if cfg.Memo.Enabled { memoStore := memo.NewFileStore(sharedDeps.ConfigManager.BaseDir(), cfg.Workdir) @@ -195,7 +188,7 @@ func BuildGatewayServerDeps(ctx context.Context, opts BootstrapOptions) (Runtime if invalidator, ok := memoSource.(interface{ InvalidateCache() }); ok { sourceInvl = invalidator.InvalidateCache } - contextBuilder = agentcontext.NewConfiguredBuilder(microCompactCfg, memoSource) + contextBuilder = agentcontext.NewConfiguredBuilder(memoSource) memoSvc = memo.NewService(memoStore, cfg.Memo, sourceInvl) toolRegistry.Register(memotool.NewRememberTool(memoSvc)) toolRegistry.Register(memotool.NewRecallTool(memoSvc)) @@ -457,11 +450,7 @@ func buildToolRegistry(cfg config.Config) (*tools.Registry, func() error, error) toolRegistry.Register(filesystem.NewGrep(cfg.Workdir)) toolRegistry.Register(filesystem.NewGlob(cfg.Workdir)) toolRegistry.Register(filesystem.NewEdit(cfg.Workdir)) - toolRegistry.Register(filesystem.NewMove(cfg.Workdir)) - toolRegistry.Register(filesystem.NewCopy(cfg.Workdir)) toolRegistry.Register(filesystem.NewDelete(cfg.Workdir)) - toolRegistry.Register(filesystem.NewCreateDir(cfg.Workdir)) - toolRegistry.Register(filesystem.NewRemoveDir(cfg.Workdir)) toolRegistry.Register(bash.New(cfg.Workdir, cfg.Shell, time.Duration(cfg.ToolTimeoutSec)*time.Second)) toolRegistry.Register(diagnosetool.New()) toolRegistry.Register(webfetch.New(webfetch.Config{ diff --git a/internal/app/bootstrap_test.go b/internal/app/bootstrap_test.go index a1af4c49d..11dc335c1 100644 --- a/internal/app/bootstrap_test.go +++ b/internal/app/bootstrap_test.go @@ -2184,9 +2184,6 @@ func (s 
*stubRemoteRuntimeForBootstrap) Close() error {
 func (s stubToolForBootstrap) Name() string { return s.name }
 func (s stubToolForBootstrap) Description() string { return "stub" }
 func (s stubToolForBootstrap) Schema() map[string]any { return map[string]any{"type": "object"} }
-func (s stubToolForBootstrap) MicroCompactPolicy() tools.MicroCompactPolicy {
-	return tools.MicroCompactPolicyCompact
-}
 func (s stubToolForBootstrap) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) {
 	return tools.ToolResult{Name: s.name, Content: s.content}, nil
 }
diff --git a/internal/checkpoint/per_edit_snapshot_test.go b/internal/checkpoint/per_edit_snapshot_test.go
index 27358e13f..5c3627abf 100644
--- a/internal/checkpoint/per_edit_snapshot_test.go
+++ b/internal/checkpoint/per_edit_snapshot_test.go
@@ -883,72 +883,6 @@ func TestCapturePostDelete_DirectoryTreeRecovery(t *testing.T) {
 	}
 }
 
-func TestRestore_RemoveDirWithNestedFiles(t *testing.T) {
-	store, workdir := newTestStore(t)
-	dir := filepath.Join(workdir, "foo")
-	child := filepath.Join(dir, "bar.txt")
-
-	// Turn 1: create tree.
-	writeWorkdirFile(t, workdir, "foo/bar.txt", "hello")
-	if _, err := store.CapturePreWrite(dir); err != nil {
-		t.Fatalf("capture dir t1: %v", err)
-	}
-	if _, err := store.CapturePreWrite(child); err != nil {
-		t.Fatalf("capture child t1: %v", err)
-	}
-	if _, err := store.Finalize("cp1"); err != nil {
-		t.Fatalf("finalize cp1: %v", err)
-	}
-	store.Reset()
-
-	// Turn 2: remove tree with recursive pre-capture + post-delete.
-	if _, err := store.CapturePreWrite(dir); err != nil {
-		t.Fatalf("capture dir t2: %v", err)
-	}
-	if _, err := store.CapturePreWrite(child); err != nil {
-		t.Fatalf("capture child t2: %v", err)
-	}
-	if err := os.RemoveAll(dir); err != nil {
-		t.Fatalf("removeAll: %v", err)
-	}
-	if err := store.CapturePostDelete([]string{dir, child}); err != nil {
-		t.Fatalf("CapturePostDelete: %v", err)
-	}
-	if _, err := store.Finalize("cp2"); err != nil {
-		t.Fatalf("finalize cp2: %v", err)
-	}
-	store.Reset()
-
-	// Turn 3: recreate tree with different content.
-	writeWorkdirFile(t, workdir, "foo/bar.txt", "world")
-	if _, err := store.CapturePreWrite(dir); err != nil {
-		t.Fatalf("capture dir t3: %v", err)
-	}
-	if _, err := store.CapturePreWrite(child); err != nil {
-		t.Fatalf("capture child t3: %v", err)
-	}
-	if _, err := store.Finalize("cp3"); err != nil {
-		t.Fatalf("finalize cp3: %v", err)
-	}
-	store.Reset()
-
-	// Restore cp2: should delete the tree.
-	if err := store.Restore(context.Background(), "cp2", ""); err != nil {
-		t.Fatalf("restore cp2: %v", err)
-	}
-	if _, err := os.Stat(dir); !os.IsNotExist(err) {
-		t.Fatalf("expected dir absent after restore cp2, stat err=%v", err)
-	}
-
-	// Restore cp1: should recreate the tree with original content.
-	if err := store.Restore(context.Background(), "cp1", ""); err != nil {
-		t.Fatalf("restore cp1: %v", err)
-	}
-	if got := mustReadFile(t, child); got != "hello" {
-		t.Fatalf("child want hello got %q", got)
-	}
-}
-
 func TestPerEditStoreHelperMethods(t *testing.T) {
 	t.Run("availability and pending lifecycle", func(t *testing.T) {
 		var nilStore *PerEditSnapshotStore
diff --git a/internal/cli/cli_ux_test.go b/internal/cli/cli_ux_test.go
index f125fcdd1..734b3de47 100644
--- a/internal/cli/cli_ux_test.go
+++ b/internal/cli/cli_ux_test.go
@@ -100,4 +100,3 @@ func TestLegacyFeishuAdapterCommandShowsMigrationHint(t *testing.T) {
 		t.Fatalf("err = %v, want contains adapter feishu", err)
 	}
 }
-
diff --git a/internal/cli/gateway_runtime_bridge.go b/internal/cli/gateway_runtime_bridge.go
index 64b3c5762..9d77d9926 100644
--- a/internal/cli/gateway_runtime_bridge.go
+++ b/internal/cli/gateway_runtime_bridge.go
@@ -40,6 +40,10 @@ type runtimeRunCanceler interface {
 	CancelRun(runID string) bool
 }
 
+type sessionAssetDeleter interface {
+	DeleteAsset(ctx context.Context, sessionID string, assetID string) error
+}
+
 type runtimeSessionCreator interface {
 	CreateSession(ctx context.Context, id string) (agentsession.Session, error)
 }
@@ -314,6 +318,9 @@ func (b *gatewayRuntimePortBridge) Run(ctx context.Context, input gateway.RunInp
 		return err
 	}
 	err := b.runtime.Submit(ctx, convertGatewayRunInput(input))
+	if agentruntime.IsMaxTurnLimitError(err) {
+		return gateway.NewRuntimeMaxTurnExceededError(err.Error())
+	}
 	if err != nil && isRuntimeNotFoundError(err) {
 		sessionID := strings.TrimSpace(input.SessionID)
 		if sessionID == "" {
@@ -326,7 +333,11 @@ func (b *gatewayRuntimePortBridge) Run(ctx context.Context, input gateway.RunInp
 		if _, createErr := creator.CreateSession(ctx, sessionID); createErr != nil {
 			return err
 		}
-		return b.runtime.Submit(ctx, convertGatewayRunInput(input))
+		retryErr := b.runtime.Submit(ctx, convertGatewayRunInput(input))
+		if agentruntime.IsMaxTurnLimitError(retryErr)
{ + return gateway.NewRuntimeMaxTurnExceededError(retryErr.Error()) + } + return retryErr } return err } @@ -497,6 +508,37 @@ func (b *gatewayRuntimePortBridge) ResolvePermission(ctx context.Context, input }) } +// ApprovePlan 将网关计划批准请求转换为 runtime 当前计划批准输入。 +func (b *gatewayRuntimePortBridge) ApprovePlan( + ctx context.Context, + input gateway.ApprovePlanInput, +) (gateway.ApprovePlanResult, error) { + if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { + return gateway.ApprovePlanResult{}, err + } + approver, ok := b.runtime.(agentruntime.PlanApprover) + if !ok { + return gateway.ApprovePlanResult{}, fmt.Errorf("gateway runtime bridge: runtime does not support plan approval") + } + sessionID := strings.TrimSpace(input.SessionID) + planID := strings.TrimSpace(input.PlanID) + if err := approver.ApproveCurrentPlan(ctx, agentruntime.ApproveCurrentPlanInput{ + SessionID: sessionID, + PlanID: planID, + Revision: input.Revision, + }); err != nil { + if agentruntime.IsPlanApprovalInvalidError(err) { + return gateway.ApprovePlanResult{}, fmt.Errorf("%w: %v", gateway.ErrRuntimeInvalidAction, err) + } + return gateway.ApprovePlanResult{}, err + } + return gateway.ApprovePlanResult{ + PlanID: planID, + Revision: input.Revision, + Status: "approved", + }, nil +} + // ResolveUserQuestion 将网关 ask_user 回答转发到 runtime。 func (b *gatewayRuntimePortBridge) ResolveUserQuestion(ctx context.Context, input gateway.UserQuestionAnswerInput) error { if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { @@ -659,6 +701,108 @@ func (b *gatewayRuntimePortBridge) CreateSession(ctx context.Context, input gate return strings.TrimSpace(session.ID), nil } +// SaveSessionAsset 将浏览器上传的附件保存到当前工作区的 session asset store。 +func (b *gatewayRuntimePortBridge) SaveSessionAsset( + ctx context.Context, + input gateway.SaveSessionAssetInput, +) (gateway.SessionAssetMeta, error) { + if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { + return gateway.SessionAssetMeta{}, err + } + 
sessionID := strings.TrimSpace(input.SessionID) + if sessionID == "" { + return gateway.SessionAssetMeta{}, gateway.ErrRuntimeResourceNotFound + } + if b.sessionStore == nil { + return gateway.SessionAssetMeta{}, fmt.Errorf("gateway runtime bridge: session store is unavailable") + } + loader, ok := b.sessionStore.(bridgeSessionLoader) + if !ok { + return gateway.SessionAssetMeta{}, fmt.Errorf("gateway runtime bridge: session asset store is unavailable") + } + if _, err := loader.LoadSession(ctx, sessionID); err != nil { + if isRuntimeNotFoundError(err) { + return gateway.SessionAssetMeta{}, gateway.ErrRuntimeResourceNotFound + } + return gateway.SessionAssetMeta{}, err + } + assetStore, ok := b.sessionStore.(agentsession.AssetStore) + if !ok || assetStore == nil { + return gateway.SessionAssetMeta{}, fmt.Errorf("gateway runtime bridge: session asset store is unavailable") + } + meta, err := assetStore.SaveAsset(ctx, sessionID, input.Reader, strings.TrimSpace(input.MimeType)) + if err != nil { + return gateway.SessionAssetMeta{}, err + } + return gateway.SessionAssetMeta{ + SessionID: sessionID, + AssetID: strings.TrimSpace(meta.ID), + MimeType: strings.TrimSpace(meta.MimeType), + Size: meta.Size, + }, nil +} + +// OpenSessionAsset 打开当前工作区的会话附件,供 Gateway HTTP 读取端点流式返回。 +func (b *gatewayRuntimePortBridge) OpenSessionAsset( + ctx context.Context, + input gateway.OpenSessionAssetInput, +) (gateway.OpenSessionAssetResult, error) { + if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { + return gateway.OpenSessionAssetResult{}, err + } + sessionID := strings.TrimSpace(input.SessionID) + assetID := strings.TrimSpace(input.AssetID) + if sessionID == "" || assetID == "" { + return gateway.OpenSessionAssetResult{}, gateway.ErrRuntimeResourceNotFound + } + assetStore, ok := b.sessionStore.(agentsession.AssetStore) + if !ok || assetStore == nil { + return gateway.OpenSessionAssetResult{}, fmt.Errorf("gateway runtime bridge: session asset store is unavailable") + } + 
reader, meta, err := assetStore.Open(ctx, sessionID, assetID) + if err != nil { + if isRuntimeNotFoundError(err) || errors.Is(err, os.ErrNotExist) { + return gateway.OpenSessionAssetResult{}, gateway.ErrRuntimeResourceNotFound + } + return gateway.OpenSessionAssetResult{}, err + } + return gateway.OpenSessionAssetResult{ + Reader: reader, + Meta: gateway.SessionAssetMeta{ + SessionID: sessionID, + AssetID: strings.TrimSpace(meta.ID), + MimeType: strings.TrimSpace(meta.MimeType), + Size: meta.Size, + }, + }, nil +} + +// DeleteSessionAsset 删除当前工作区的会话附件,供 Web 在取消上传引用时释放服务端文件。 +func (b *gatewayRuntimePortBridge) DeleteSessionAsset(ctx context.Context, input gateway.DeleteSessionAssetInput) error { + if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { + return err + } + sessionID := strings.TrimSpace(input.SessionID) + assetID := strings.TrimSpace(input.AssetID) + if sessionID == "" || assetID == "" { + return gateway.ErrRuntimeResourceNotFound + } + if b.sessionStore == nil { + return fmt.Errorf("gateway runtime bridge: session store is unavailable") + } + deleter, ok := b.sessionStore.(sessionAssetDeleter) + if !ok || deleter == nil { + return fmt.Errorf("gateway runtime bridge: session asset store does not support delete") + } + if err := deleter.DeleteAsset(ctx, sessionID, assetID); err != nil { + if isRuntimeNotFoundError(err) { + return nil + } + return err + } + return nil +} + // DeleteSession 删除/归档指定会话。 func (b *gatewayRuntimePortBridge) DeleteSession(ctx context.Context, input gateway.DeleteSessionInput) (bool, error) { if err := b.ensureRuntimeAccess(input.SubjectID); err != nil { @@ -1646,11 +1790,13 @@ func convertGatewayRunInput(input gateway.RunInput) agentruntime.PrepareInput { continue } path := strings.TrimSpace(part.Media.URI) - if path == "" { + assetID := strings.TrimSpace(part.Media.AssetID) + if path == "" && assetID == "" { continue } images = append(images, agentruntime.UserImageInput{ Path: path, + AssetID: assetID, MimeType: 
strings.TrimSpace(part.Media.MimeType), }) } @@ -1829,6 +1975,7 @@ func convertSessionMessages(messages []providertypes.Message) []gateway.SessionM convertedMessage := gateway.SessionMessage{ Role: strings.TrimSpace(message.Role), Content: renderSessionMessageContent(message.Parts), + Parts: convertProviderContentParts(message.Parts), ToolCallID: strings.TrimSpace(message.ToolCallID), IsError: message.IsError, } @@ -1847,18 +1994,116 @@ func convertSessionMessages(messages []providertypes.Message) []gateway.SessionM return converted } +// convertProviderContentParts 将 provider 通用内容分片转换为 Gateway 会话快照分片。 +func convertProviderContentParts(parts []providertypes.ContentPart) []gateway.InputPart { + if len(parts) == 0 { + return nil + } + converted := make([]gateway.InputPart, 0, len(parts)) + for _, part := range parts { + switch part.Kind { + case providertypes.ContentPartText: + if text := strings.TrimSpace(part.Text); text != "" { + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeText, + Text: text, + }) + } + case providertypes.ContentPartImage: + if part.Image == nil { + continue + } + switch part.Image.SourceType { + case providertypes.ImageSourceSessionAsset: + if part.Image.Asset == nil || strings.TrimSpace(part.Image.Asset.ID) == "" { + continue + } + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeImage, + Media: &gateway.Media{ + AssetID: strings.TrimSpace(part.Image.Asset.ID), + MimeType: strings.TrimSpace(part.Image.Asset.MimeType), + }, + }) + case providertypes.ImageSourceRemote: + if url := strings.TrimSpace(part.Image.URL); url != "" { + converted = append(converted, gateway.InputPart{ + Type: gateway.InputPartTypeImage, + Media: &gateway.Media{ + URI: url, + }, + }) + } + } + } + } + return converted +} + +// convertRuntimePlanTodoItem 将 session 计划中的 legacy todo 项映射为 gateway 展示结构。 +func convertRuntimePlanTodoItem(item agentsession.TodoItem) gateway.PlanTodoItem { + required := false + if 
item.Required != nil { + required = *item.Required + } + return gateway.PlanTodoItem{ + ID: strings.TrimSpace(item.ID), + Content: strings.TrimSpace(item.Content), + Status: strings.TrimSpace(string(item.Status)), + Required: required, + Artifacts: append([]string(nil), item.Artifacts...), + FailureReason: strings.TrimSpace(item.FailureReason), + BlockedReason: strings.TrimSpace(string(item.BlockedReason)), + Revision: item.Revision, + } +} + +// convertRuntimePlanArtifact 将 runtime 当前计划快照映射为 gateway 公开契约。 +func convertRuntimePlanArtifact(plan *agentsession.PlanArtifact) *gateway.PlanArtifact { + if plan == nil { + return nil + } + converted := &gateway.PlanArtifact{ + ID: strings.TrimSpace(plan.ID), + Revision: plan.Revision, + Status: strings.TrimSpace(string(plan.Status)), + Spec: gateway.PlanSpec{ + Goal: strings.TrimSpace(plan.Spec.Goal), + Steps: append([]string(nil), plan.Spec.Steps...), + Constraints: append([]string(nil), plan.Spec.Constraints...), + OpenQuestions: append([]string(nil), plan.Spec.OpenQuestions...), + }, + Summary: gateway.PlanSummaryView{ + Goal: strings.TrimSpace(plan.Summary.Goal), + KeySteps: append([]string(nil), plan.Summary.KeySteps...), + Constraints: append([]string(nil), plan.Summary.Constraints...), + ActiveTodoIDs: append([]string(nil), plan.Summary.ActiveTodoIDs...), + }, + CreatedAt: plan.CreatedAt, + UpdatedAt: plan.UpdatedAt, + } + if len(plan.Spec.Todos) > 0 { + converted.Spec.Todos = make([]gateway.PlanTodoItem, 0, len(plan.Spec.Todos)) + for _, item := range plan.Spec.Todos { + converted.Spec.Todos = append(converted.Spec.Todos, convertRuntimePlanTodoItem(item)) + } + } + return converted +} + // convertRuntimeSessionToGatewaySession 将 runtime 会话结构映射为 gateway 契约返回值。 func convertRuntimeSessionToGatewaySession(session agentsession.Session) gateway.Session { return gateway.Session{ - ID: strings.TrimSpace(session.ID), - Title: strings.TrimSpace(session.Title), - CreatedAt: session.CreatedAt, - UpdatedAt: session.UpdatedAt, - 
Workdir: strings.TrimSpace(session.Workdir), - Provider: strings.TrimSpace(session.Provider), - Model: strings.TrimSpace(session.Model), - AgentMode: strings.TrimSpace(string(session.AgentMode)), - Messages: convertSessionMessages(session.Messages), + ID: strings.TrimSpace(session.ID), + Title: strings.TrimSpace(session.Title), + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Workdir: strings.TrimSpace(session.Workdir), + Provider: strings.TrimSpace(session.Provider), + Model: strings.TrimSpace(session.Model), + AgentMode: strings.TrimSpace(string(session.AgentMode)), + CurrentPlan: convertRuntimePlanArtifact(session.CurrentPlan), + Messages: convertSessionMessages(session.Messages), } } @@ -2446,6 +2691,7 @@ type manualModelPayload struct { } var _ gateway.RuntimePort = (*gatewayRuntimePortBridge)(nil) +var _ gateway.SessionAssetPort = (*gatewayRuntimePortBridge)(nil) func (b *gatewayRuntimePortBridge) ListCheckpoints(ctx context.Context, input gateway.ListCheckpointsInput) ([]gateway.CheckpointEntry, error) { cp, ok := b.runtime.(runtimeCheckpointer) diff --git a/internal/cli/gateway_runtime_bridge_test.go b/internal/cli/gateway_runtime_bridge_test.go index 117750fc1..01c6b92d7 100644 --- a/internal/cli/gateway_runtime_bridge_test.go +++ b/internal/cli/gateway_runtime_bridge_test.go @@ -1,10 +1,12 @@ package cli import ( + "bytes" "context" "encoding/json" "errors" "fmt" + "io" "os" "os/exec" "path/filepath" @@ -89,6 +91,12 @@ type runtimeStub struct { checkpointDiffErr error } +type runtimePlanApproverStub struct { + *runtimeStub + approveInput agentruntime.ApproveCurrentPlanInput + approveErr error +} + const testBridgeSubjectID = bridgeLocalSubjectID func (s *runtimeStub) Submit(_ context.Context, input agentruntime.PrepareInput) error { @@ -132,6 +140,14 @@ func (s *runtimeStub) ResolvePermission(_ context.Context, input agentruntime.Pe return s.permissionErr } +func (s *runtimePlanApproverStub) ApproveCurrentPlan( + _ context.Context, + 
input agentruntime.ApproveCurrentPlanInput, +) error { + s.approveInput = input + return s.approveErr +} + func (s *runtimeStub) ResolveUserQuestion(_ context.Context, input agentruntime.UserQuestionResolutionInput) error { s.userQuestionInput = input return s.userQuestionErr @@ -1075,6 +1091,101 @@ func TestGatewayRuntimePortBridgeListSessionTodosAndSnapshot(t *testing.T) { }) } +func TestGatewayRuntimePortBridgeApprovePlan(t *testing.T) { + runtimeSvc := &runtimePlanApproverStub{ + runtimeStub: &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + } + bridge, err := newGatewayRuntimePortBridge(context.Background(), runtimeSvc, testSessionStore) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + t.Cleanup(func() { _ = bridge.Close() }) + + result, err := bridge.ApprovePlan(context.Background(), gateway.ApprovePlanInput{ + SubjectID: testBridgeSubjectID, + SessionID: " session-1 ", + PlanID: " plan-1 ", + Revision: 3, + }) + if err != nil { + t.Fatalf("approve_plan: %v", err) + } + if runtimeSvc.approveInput.SessionID != "session-1" || runtimeSvc.approveInput.PlanID != "plan-1" || runtimeSvc.approveInput.Revision != 3 { + t.Fatalf("approve input = %#v, want trimmed session/plan revision", runtimeSvc.approveInput) + } + if result.PlanID != "plan-1" || result.Revision != 3 || result.Status != "approved" { + t.Fatalf("approve result = %#v, want approved plan-1 revision 3", result) + } +} + +func TestGatewayRuntimePortBridgeApprovePlanUnsupportedRuntime(t *testing.T) { + bridge, err := newGatewayRuntimePortBridge( + context.Background(), + &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + testSessionStore, + ) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + t.Cleanup(func() { _ = bridge.Close() }) + + _, err = bridge.ApprovePlan(context.Background(), gateway.ApprovePlanInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + PlanID: "plan-1", + Revision: 1, + }) + if err == nil || !strings.Contains(err.Error(), 
"runtime does not support plan approval") { + t.Fatalf("approve_plan unsupported error = %v", err) + } +} + +func TestGatewayRuntimePortBridgeApprovePlanInvalidAction(t *testing.T) { + runtimeSvc := &runtimePlanApproverStub{ + runtimeStub: &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + approveErr: agentruntime.ErrPlanApprovalRevisionMismatch, + } + bridge, err := newGatewayRuntimePortBridge(context.Background(), runtimeSvc, testSessionStore) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + t.Cleanup(func() { _ = bridge.Close() }) + + _, err = bridge.ApprovePlan(context.Background(), gateway.ApprovePlanInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + PlanID: "plan-1", + Revision: 1, + }) + if !errors.Is(err, gateway.ErrRuntimeInvalidAction) { + t.Fatalf("approve_plan error = %v, want ErrRuntimeInvalidAction", err) + } +} + +func TestGatewayRuntimePortBridgeApprovePlanAccessDenied(t *testing.T) { + runtimeSvc := &runtimePlanApproverStub{ + runtimeStub: &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + } + bridge, err := newGatewayRuntimePortBridge(context.Background(), runtimeSvc, testSessionStore) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + t.Cleanup(func() { _ = bridge.Close() }) + + _, err = bridge.ApprovePlan(context.Background(), gateway.ApprovePlanInput{ + SubjectID: "other-subject", + SessionID: "session-1", + PlanID: "plan-1", + Revision: 1, + }) + if !errors.Is(err, gateway.ErrRuntimeAccessDenied) { + t.Fatalf("approve_plan error = %v, want ErrRuntimeAccessDenied", err) + } + if runtimeSvc.approveInput.SessionID != "" { + t.Fatalf("runtime approve should not be called, input = %#v", runtimeSvc.approveInput) + } +} + func TestGatewayRuntimePortBridgeLoadSessionNotFoundBranches(t *testing.T) { t.Parallel() @@ -1441,6 +1552,7 @@ func TestConvertGatewayRunInputAndSessionHelpers(t *testing.T) { {Type: gateway.InputPartTypeImage, Media: nil}, {Type: gateway.InputPartTypeImage, Media: 
&gateway.Media{URI: " "}}, {Type: gateway.InputPartTypeImage, Media: &gateway.Media{URI: " /tmp/a.png ", MimeType: " image/png "}}, + {Type: gateway.InputPartTypeImage, Media: &gateway.Media{AssetID: " asset-1 ", MimeType: " image/webp "}}, }, Workdir: " /tmp/work ", }) @@ -1450,8 +1562,14 @@ func TestConvertGatewayRunInputAndSessionHelpers(t *testing.T) { if converted.Text != "base\ntext" { t.Fatalf("text = %q, want %q", converted.Text, "base\ntext") } - if len(converted.Images) != 1 || converted.Images[0].Path != "/tmp/a.png" { - t.Fatalf("images = %#v, want one valid image", converted.Images) + if len(converted.Images) != 2 { + t.Fatalf("images = %#v, want two valid images", converted.Images) + } + if converted.Images[0].Path != "/tmp/a.png" || converted.Images[0].MimeType != "image/png" { + t.Fatalf("local image = %#v, want normalized path/mime", converted.Images[0]) + } + if converted.Images[1].AssetID != "asset-1" || converted.Images[1].MimeType != "image/webp" { + t.Fatalf("asset image = %#v, want normalized asset_id/mime", converted.Images[1]) } if got := renderSessionMessageContent(nil); got != "" { @@ -1471,6 +1589,205 @@ func TestConvertGatewayRunInputAndSessionHelpers(t *testing.T) { } } +func TestGatewayRuntimePortBridgeSessionAssets(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := agentsession.NewSQLiteStore(t.TempDir(), workdir) + t.Cleanup(func() { _ = store.Close() }) + session := agentsession.NewWithWorkdir("asset session", workdir) + if _, err := store.CreateSession(context.Background(), agentsession.CreateSessionInput{ + ID: session.ID, + Title: session.Title, + CreatedAt: session.CreatedAt, + UpdatedAt: session.UpdatedAt, + Head: session.HeadSnapshot(), + }); err != nil { + t.Fatalf("CreateSession() error = %v", err) + } + + bridge, err := newGatewayRuntimePortBridge( + context.Background(), + &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + store, + ) + if err != nil { + t.Fatalf("new bridge: %v", err) + } 
+ defer bridge.Close() + + payload := []byte("image payload") + meta, err := bridge.SaveSessionAsset(context.Background(), gateway.SaveSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: " " + session.ID + " ", + Reader: bytes.NewReader(payload), + MimeType: " image/png ", + }) + if err != nil { + t.Fatalf("SaveSessionAsset() error = %v", err) + } + if meta.SessionID != session.ID || meta.AssetID == "" || meta.MimeType != "image/png" || meta.Size != int64(len(payload)) { + t.Fatalf("unexpected saved meta: %+v", meta) + } + + opened, err := bridge.OpenSessionAsset(context.Background(), gateway.OpenSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: session.ID, + AssetID: " " + meta.AssetID + " ", + }) + if err != nil { + t.Fatalf("OpenSessionAsset() error = %v", err) + } + got, err := io.ReadAll(opened.Reader) + if err != nil { + t.Fatalf("ReadAll() error = %v", err) + } + if string(got) != string(payload) || opened.Meta.AssetID != meta.AssetID || opened.Meta.MimeType != "image/png" { + t.Fatalf("unexpected opened asset meta=%+v payload=%q", opened.Meta, string(got)) + } + if err := opened.Reader.Close(); err != nil { + t.Fatalf("Close opened asset reader: %v", err) + } + + if err := bridge.DeleteSessionAsset(context.Background(), gateway.DeleteSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: session.ID, + AssetID: meta.AssetID, + }); err != nil { + t.Fatalf("DeleteSessionAsset() error = %v", err) + } + if err := bridge.DeleteSessionAsset(context.Background(), gateway.DeleteSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: session.ID, + AssetID: meta.AssetID, + }); err != nil { + t.Fatalf("DeleteSessionAsset() should be idempotent, got %v", err) + } + if _, err := bridge.OpenSessionAsset(context.Background(), gateway.OpenSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: session.ID, + AssetID: meta.AssetID, + }); !errors.Is(err, gateway.ErrRuntimeResourceNotFound) { + 
t.Fatalf("OpenSessionAsset() after delete error = %v, want resource not found", err) + } +} + +func TestGatewayRuntimePortBridgeSessionAssetErrors(t *testing.T) { + t.Parallel() + + bridge, err := newGatewayRuntimePortBridge( + context.Background(), + &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + testSessionStore, + ) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + defer bridge.Close() + + if _, err := bridge.SaveSessionAsset(context.Background(), gateway.SaveSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: " ", + Reader: strings.NewReader("x"), + MimeType: "image/png", + }); err == nil { + t.Fatal("expected empty session id save error") + } + if _, err := bridge.OpenSessionAsset(context.Background(), gateway.OpenSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + AssetID: " ", + }); err == nil { + t.Fatal("expected empty asset id open error") + } + if _, err := bridge.SaveSessionAsset(context.Background(), gateway.SaveSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + Reader: strings.NewReader("x"), + MimeType: "image/png", + }); err == nil || !strings.Contains(err.Error(), "asset store is unavailable") { + t.Fatalf("expected unavailable asset store save error, got %v", err) + } + if err := bridge.DeleteSessionAsset(context.Background(), gateway.DeleteSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + AssetID: "asset-1", + }); err == nil || !strings.Contains(err.Error(), "does not support delete") { + t.Fatalf("expected unavailable asset store delete error, got %v", err) + } + if _, err := bridge.OpenSessionAsset(context.Background(), gateway.OpenSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: "session-1", + AssetID: "asset-1", + }); err == nil || !strings.Contains(err.Error(), "asset store is unavailable") { + t.Fatalf("expected unavailable asset store open error, got %v", err) + } +} + +func 
TestGatewayRuntimePortBridgeSessionAssetSaveRequiresExistingSession(t *testing.T) { + t.Parallel() + + store := agentsession.NewSQLiteStore(t.TempDir(), t.TempDir()) + t.Cleanup(func() { _ = store.Close() }) + bridge, err := newGatewayRuntimePortBridge( + context.Background(), + &runtimeStub{eventsCh: make(chan agentruntime.RuntimeEvent, 1)}, + store, + ) + if err != nil { + t.Fatalf("new bridge: %v", err) + } + defer bridge.Close() + + _, err = bridge.SaveSessionAsset(context.Background(), gateway.SaveSessionAssetInput{ + SubjectID: testBridgeSubjectID, + SessionID: "missing-session", + Reader: strings.NewReader("x"), + MimeType: "image/png", + }) + if !errors.Is(err, gateway.ErrRuntimeResourceNotFound) { + t.Fatalf("SaveSessionAsset() missing session error = %v, want resource not found", err) + } +} + +func TestConvertRuntimeSessionToGatewaySessionIncludesCurrentPlan(t *testing.T) { + required := true + session := agentsession.New("plan session") + session.AgentMode = agentsession.AgentModePlan + session.CurrentPlan = &agentsession.PlanArtifact{ + ID: "plan-1", + Revision: 2, + Status: agentsession.PlanStatusDraft, + Spec: agentsession.PlanSpec{ + Goal: "修复 web plan 展示", + Steps: []string{"发事件", "渲染卡片"}, + Constraints: []string{"不创建执行 todo"}, + OpenQuestions: []string{"是否需要审批按钮"}, + Todos: []agentsession.TodoItem{{ + ID: "todo-1", + Content: "legacy todo", + Status: agentsession.TodoStatusPending, + Required: &required, + }}, + }, + Summary: agentsession.SummaryView{ + Goal: "修复 web plan 展示", + KeySteps: []string{"发事件"}, + }, + } + + converted := convertRuntimeSessionToGatewaySession(session) + if converted.CurrentPlan == nil { + t.Fatal("expected current_plan to be present") + } + if converted.CurrentPlan.ID != "plan-1" || converted.CurrentPlan.Spec.Goal != "修复 web plan 展示" { + t.Fatalf("unexpected current_plan: %+v", converted.CurrentPlan) + } + if len(converted.CurrentPlan.Spec.Todos) != 1 || !converted.CurrentPlan.Spec.Todos[0].Required { + 
t.Fatalf("unexpected plan todos: %+v", converted.CurrentPlan.Spec.Todos) + } +} + func TestGatewayRuntimePortBridgeDeleteSession(t *testing.T) { t.Run("success", func(t *testing.T) { store := &bridgeSessionStoreStub{ diff --git a/internal/cli/root_test.go b/internal/cli/root_test.go index 08ca8bbdc..2f5ebb98c 100644 --- a/internal/cli/root_test.go +++ b/internal/cli/root_test.go @@ -1100,6 +1100,14 @@ func (stubRuntimePort) CreateSession(context.Context, gateway.CreateSessionInput return "", nil } +func (stubRuntimePort) SaveSessionAsset(context.Context, gateway.SaveSessionAssetInput) (gateway.SessionAssetMeta, error) { + return gateway.SessionAssetMeta{}, nil +} + +func (stubRuntimePort) OpenSessionAsset(context.Context, gateway.OpenSessionAssetInput) (gateway.OpenSessionAssetResult, error) { + return gateway.OpenSessionAssetResult{}, nil +} + func (stubRuntimePort) ListSessionTodos(context.Context, gateway.ListSessionTodosInput) (gateway.TodoSnapshot, error) { return gateway.TodoSnapshot{}, nil } diff --git a/internal/cli/web_command.go b/internal/cli/web_command.go index 3722da13d..94e914c78 100644 --- a/internal/cli/web_command.go +++ b/internal/cli/web_command.go @@ -25,6 +25,8 @@ var ( webCommandStartGatewayServer = startGatewayServer webCommandBuildFrontend = buildFrontend webCommandLookPath = exec.LookPath + openBrowserFn = openBrowser + userHomeDirFn = os.UserHomeDir webCommandEmbeddedAssets = func() (fs.FS, bool) { if !webassets.IsAvailable() { return nil, false @@ -327,7 +329,7 @@ func waitForGatewayAndOpenBrowser(ctx context.Context, address string, logger *l browserURL += "/?token=" + token } logger.Printf("gateway is ready, opening browser: %s", baseURL) - if openErr := openBrowser(browserURL); openErr != nil { + if openErr := openBrowserFn(browserURL); openErr != nil { logger.Printf("failed to open browser: %v (open %s manually)", openErr, browserURL) } return @@ -340,7 +342,7 @@ func waitForGatewayAndOpenBrowser(ctx context.Context, address string, 
logger *l // readGatewayToken 从 ~/.neocode/auth.json 读取认证 token。 func readGatewayToken() string { - homeDir, err := os.UserHomeDir() + homeDir, err := userHomeDirFn() if err != nil { return "" } diff --git a/internal/cli/web_command_test.go b/internal/cli/web_command_test.go index dd48daab4..3f228094b 100644 --- a/internal/cli/web_command_test.go +++ b/internal/cli/web_command_test.go @@ -423,6 +423,13 @@ func TestBuildFrontendAndReadGatewayToken(t *testing.T) { if err := os.WriteFile(filepath.Join(authDir, "auth.json"), authData, 0o644); err != nil { t.Fatalf("write auth.json: %v", err) } + originalUserHomeDir := userHomeDirFn + userHomeDirFn = func() (string, error) { + return homeDir, nil + } + t.Cleanup(func() { + userHomeDirFn = originalUserHomeDir + }) originalHome := os.Getenv("HOME") if err := os.Setenv("HOME", homeDir); err != nil { t.Fatalf("set HOME: %v", err) @@ -449,6 +456,13 @@ func TestWaitForGatewayAndOpenBrowserAndResolveListenAddress(t *testing.T) { if err := os.WriteFile(filepath.Join(authDir, "auth.json"), authData, 0o644); err != nil { t.Fatalf("write auth.json: %v", err) } + originalUserHomeDir := userHomeDirFn + userHomeDirFn = func() (string, error) { + return homeDir, nil + } + t.Cleanup(func() { + userHomeDirFn = originalUserHomeDir + }) originalHome := os.Getenv("HOME") if err := os.Setenv("HOME", homeDir); err != nil { t.Fatalf("set HOME: %v", err) @@ -474,6 +488,13 @@ func TestWaitForGatewayAndOpenBrowserAndResolveListenAddress(t *testing.T) { t.Cleanup(func() { _ = os.Setenv("PATH", originalPath) }) + originalOpenBrowser := openBrowserFn + openBrowserFn = func(url string) error { + return os.WriteFile(openLog, []byte(url), 0o644) + } + t.Cleanup(func() { + openBrowserFn = originalOpenBrowser + }) server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/healthz" { @@ -543,6 +564,13 @@ func TestResolveWebStaticDirCurrentWorkdirAndReadGatewayTokenInvalid(t *testing. 
if err := os.WriteFile(filepath.Join(authDir, "auth.json"), []byte("{invalid"), 0o644); err != nil { t.Fatalf("write invalid auth.json: %v", err) } + originalUserHomeDir := userHomeDirFn + userHomeDirFn = func() (string, error) { + return homeDir, nil + } + t.Cleanup(func() { + userHomeDirFn = originalUserHomeDir + }) originalHome := os.Getenv("HOME") if err := os.Setenv("HOME", homeDir); err != nil { t.Fatalf("set HOME: %v", err) diff --git a/internal/config/config_test.go b/internal/config/config_test.go index c6a6cfd50..4db239595 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -1129,15 +1129,11 @@ func TestCompactConfigDefaultsAndRoundTrip(t *testing.T) { compactCfg.ReadTimeMaxMessageSpans, ) } - if compactCfg.MicroCompactDisabled { - t.Fatalf("expected micro compact to be enabled by default") - } cfg.Context.Compact.ManualStrategy = CompactManualStrategyFullReplace cfg.Context.Compact.ManualKeepRecentMessages = 2 cfg.Context.Compact.MaxSummaryChars = 900 cfg.Context.Compact.ReadTimeMaxMessageSpans = 30 - cfg.Context.Compact.MicroCompactDisabled = true if err := loader.Save(context.Background(), cfg); err != nil { t.Fatalf("Save() error = %v", err) } @@ -1152,9 +1148,6 @@ func TestCompactConfigDefaultsAndRoundTrip(t *testing.T) { if strings.Contains(text, "manual_keep_recent_spans:") { t.Fatalf("expected persisted config to drop legacy manual_keep_recent_spans key, got:\n%s", text) } - if !strings.Contains(text, "micro_compact_disabled: true") { - t.Fatalf("expected persisted config to include micro_compact_disabled, got:\n%s", text) - } if !strings.Contains(text, "read_time_max_message_spans: 30") { t.Fatalf("expected persisted config to include read_time_max_message_spans, got:\n%s", text) } @@ -1175,9 +1168,6 @@ func TestCompactConfigDefaultsAndRoundTrip(t *testing.T) { if reloaded.Context.Compact.ReadTimeMaxMessageSpans != 30 { t.Fatalf("expected read_time_max_message_spans=30, got %d", 
reloaded.Context.Compact.ReadTimeMaxMessageSpans) } - if !reloaded.Context.Compact.MicroCompactDisabled { - t.Fatalf("expected micro_compact_disabled to persist") - } } func TestCompactConfigValidateFailures(t *testing.T) { diff --git a/internal/config/context.go b/internal/config/context.go index 139848bdd..703adecb9 100644 --- a/internal/config/context.go +++ b/internal/config/context.go @@ -13,7 +13,6 @@ const ( DefaultBudgetReserveTokens = 13000 DefaultBudgetFallbackPromptBudget = 100000 DefaultBudgetMaxReactiveCompacts = 3 - DefaultMicroCompactRetainedToolSpans = 6 DefaultCompactReadTimeMaxMessageSpans = 24 DefaultAskMaxInputTokens = 8000 DefaultAskRetainTurns = 5 @@ -30,13 +29,11 @@ type ContextConfig struct { } type CompactConfig struct { - ManualStrategy string `yaml:"manual_strategy,omitempty"` - ManualKeepRecentMessages int `yaml:"manual_keep_recent_messages,omitempty"` - MaxSummaryChars int `yaml:"max_summary_chars,omitempty"` - MicroCompactDisabled bool `yaml:"micro_compact_disabled,omitempty"` - MicroCompactRetainedToolSpans int `yaml:"micro_compact_retained_tool_spans,omitempty"` - ReadTimeMaxMessageSpans int `yaml:"read_time_max_message_spans,omitempty"` - MaxArchivedPromptChars int `yaml:"max_archived_prompt_chars,omitempty"` + ManualStrategy string `yaml:"manual_strategy,omitempty"` + ManualKeepRecentMessages int `yaml:"manual_keep_recent_messages,omitempty"` + MaxSummaryChars int `yaml:"max_summary_chars,omitempty"` + ReadTimeMaxMessageSpans int `yaml:"read_time_max_message_spans,omitempty"` + MaxArchivedPromptChars int `yaml:"max_archived_prompt_chars,omitempty"` } // BudgetConfig defines the configuration for the context budget control plane. @@ -76,11 +73,10 @@ func defaultBudgetConfig() BudgetConfig { // defaultCompactConfig returns the default configuration for the manual compact strategy. func defaultCompactConfig() CompactConfig { return CompactConfig{ - ManualStrategy: CompactManualStrategyKeepRecent, - ManualKeepRecentMessages: DefaultCompactManualKeepRecentMessages, - MaxSummaryChars: DefaultCompactMaxSummaryChars, -
MicroCompactRetainedToolSpans: DefaultMicroCompactRetainedToolSpans, - ReadTimeMaxMessageSpans: DefaultCompactReadTimeMaxMessageSpans, + ManualStrategy: CompactManualStrategyKeepRecent, + ManualKeepRecentMessages: DefaultCompactManualKeepRecentMessages, + MaxSummaryChars: DefaultCompactMaxSummaryChars, + ReadTimeMaxMessageSpans: DefaultCompactReadTimeMaxMessageSpans, } } @@ -143,9 +139,6 @@ func (c *CompactConfig) ApplyDefaults(defaults CompactConfig) { if c.MaxSummaryChars <= 0 { c.MaxSummaryChars = defaults.MaxSummaryChars } - if c.MicroCompactRetainedToolSpans <= 0 { - c.MicroCompactRetainedToolSpans = defaults.MicroCompactRetainedToolSpans - } if c.ReadTimeMaxMessageSpans <= 0 { c.ReadTimeMaxMessageSpans = defaults.ReadTimeMaxMessageSpans } diff --git a/internal/config/loader.go b/internal/config/loader.go index a07a689c9..940703fdd 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -43,13 +43,11 @@ type persistedContextConfig struct { } type persistedCompactConfig struct { - ManualStrategy string `yaml:"manual_strategy,omitempty"` - ManualKeepRecentMessages int `yaml:"manual_keep_recent_messages,omitempty"` - MaxSummaryChars int `yaml:"max_summary_chars,omitempty"` - MicroCompactDisabled bool `yaml:"micro_compact_disabled,omitempty"` - MicroCompactRetainedToolSpans int `yaml:"micro_compact_retained_tool_spans,omitempty"` - ReadTimeMaxMessageSpans int `yaml:"read_time_max_message_spans,omitempty"` - MaxArchivedPromptChars int `yaml:"max_archived_prompt_chars,omitempty"` + ManualStrategy string `yaml:"manual_strategy,omitempty"` + ManualKeepRecentMessages int `yaml:"manual_keep_recent_messages,omitempty"` + MaxSummaryChars int `yaml:"max_summary_chars,omitempty"` + ReadTimeMaxMessageSpans int `yaml:"read_time_max_message_spans,omitempty"` + MaxArchivedPromptChars int `yaml:"max_archived_prompt_chars,omitempty"` } type persistedBudgetConfig struct { @@ -284,13 +282,11 @@ func marshalPersistedConfig(snapshot Config) ([]byte, error) { func 
newPersistedContextConfig(cfg ContextConfig) persistedContextConfig { return persistedContextConfig{ Compact: persistedCompactConfig{ - ManualStrategy: cfg.Compact.ManualStrategy, - ManualKeepRecentMessages: cfg.Compact.ManualKeepRecentMessages, - MaxSummaryChars: cfg.Compact.MaxSummaryChars, - MicroCompactDisabled: cfg.Compact.MicroCompactDisabled, - MicroCompactRetainedToolSpans: cfg.Compact.MicroCompactRetainedToolSpans, - ReadTimeMaxMessageSpans: cfg.Compact.ReadTimeMaxMessageSpans, - MaxArchivedPromptChars: cfg.Compact.MaxArchivedPromptChars, + ManualStrategy: cfg.Compact.ManualStrategy, + ManualKeepRecentMessages: cfg.Compact.ManualKeepRecentMessages, + MaxSummaryChars: cfg.Compact.MaxSummaryChars, + ReadTimeMaxMessageSpans: cfg.Compact.ReadTimeMaxMessageSpans, + MaxArchivedPromptChars: cfg.Compact.MaxArchivedPromptChars, }, Budget: persistedBudgetConfig{ PromptBudget: cfg.Budget.PromptBudget, @@ -310,13 +306,11 @@ func newPersistedContextConfig(cfg ContextConfig) persistedContextConfig { func fromPersistedContextConfig(file persistedContextConfig, defaults ContextConfig) ContextConfig { out := ContextConfig{ Compact: CompactConfig{ - ManualStrategy: strings.TrimSpace(file.Compact.ManualStrategy), - ManualKeepRecentMessages: file.Compact.ManualKeepRecentMessages, - MaxSummaryChars: file.Compact.MaxSummaryChars, - MicroCompactDisabled: file.Compact.MicroCompactDisabled, - MicroCompactRetainedToolSpans: file.Compact.MicroCompactRetainedToolSpans, - ReadTimeMaxMessageSpans: file.Compact.ReadTimeMaxMessageSpans, - MaxArchivedPromptChars: file.Compact.MaxArchivedPromptChars, + ManualStrategy: strings.TrimSpace(file.Compact.ManualStrategy), + ManualKeepRecentMessages: file.Compact.ManualKeepRecentMessages, + MaxSummaryChars: file.Compact.MaxSummaryChars, + ReadTimeMaxMessageSpans: file.Compact.ReadTimeMaxMessageSpans, + MaxArchivedPromptChars: file.Compact.MaxArchivedPromptChars, }, Budget: BudgetConfig{ PromptBudget: file.Budget.PromptBudget, diff --git 
a/internal/config/loader_test.go b/internal/config/loader_test.go index 65d14c6f9..bc5fda873 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -126,8 +126,9 @@ runtime: priority: 100 timeout_sec: 2 failure_policy: warn_only - params: + match: tool_name: bash + params: message: "bash is called" ` writeLoaderConfig(t, loader, raw) @@ -2105,7 +2106,6 @@ context: manual_strategy: keep_recent manual_keep_recent_messages: 9 max_summary_chars: 900 - micro_compact_retained_tool_spans: 4 max_archived_prompt_chars: 4096 ` writeLoaderConfig(t, loader, raw) @@ -2114,9 +2114,6 @@ context: if err != nil { t.Fatalf("Load() error = %v", err) } - if cfg.Context.Compact.MicroCompactRetainedToolSpans != 4 { - t.Fatalf("expected micro_compact_retained_tool_spans=4, got %d", cfg.Context.Compact.MicroCompactRetainedToolSpans) - } if cfg.Context.Compact.MaxArchivedPromptChars != 4096 { t.Fatalf("expected max_archived_prompt_chars=4096, got %d", cfg.Context.Compact.MaxArchivedPromptChars) } @@ -2127,7 +2124,6 @@ func TestLoaderSaveRoundTripsCompactExtendedFields(t *testing.T) { loader := NewLoader(t.TempDir(), testDefaultConfig()) cfg := loader.DefaultConfig() - cfg.Context.Compact.MicroCompactRetainedToolSpans = 5 cfg.Context.Compact.MaxArchivedPromptChars = 3072 if err := loader.Save(context.Background(), &cfg); err != nil { @@ -2139,9 +2135,6 @@ func TestLoaderSaveRoundTripsCompactExtendedFields(t *testing.T) { t.Fatalf("read config: %v", err) } text := string(data) - if !strings.Contains(text, "micro_compact_retained_tool_spans: 5") { - t.Fatalf("expected persisted micro_compact_retained_tool_spans, got:\n%s", text) - } if !strings.Contains(text, "max_archived_prompt_chars: 3072") { t.Fatalf("expected persisted max_archived_prompt_chars, got:\n%s", text) } @@ -2150,9 +2143,6 @@ func TestLoaderSaveRoundTripsCompactExtendedFields(t *testing.T) { if err != nil { t.Fatalf("Load() error = %v", err) } - if 
loaded.Context.Compact.MicroCompactRetainedToolSpans != 5 { - t.Fatalf("expected round-trip micro_compact_retained_tool_spans=5, got %d", loaded.Context.Compact.MicroCompactRetainedToolSpans) - } if loaded.Context.Compact.MaxArchivedPromptChars != 3072 { t.Fatalf("expected round-trip max_archived_prompt_chars=3072, got %d", loaded.Context.Compact.MaxArchivedPromptChars) } diff --git a/internal/config/runtime_hooks.go b/internal/config/runtime_hooks.go index b793f3e7b..5eaa3c49f 100644 --- a/internal/config/runtime_hooks.go +++ b/internal/config/runtime_hooks.go @@ -5,6 +5,8 @@ import ( "net" "net/url" "strings" + + "neo-code/internal/runtime/hooks" ) const ( @@ -38,22 +40,6 @@ var runtimeHookExternalKinds = map[string]struct{}{ "agent": {}, } -const ( - runtimeHookPointBeforeToolCall = "before_tool_call" - runtimeHookPointAfterToolResult = "after_tool_result" - runtimeHookPointBeforeCompletionDecision = "before_completion_decision" - runtimeHookPointAcceptGate = "accept_gate" - runtimeHookPointBeforePermissionDecision = "before_permission_decision" - runtimeHookPointAfterToolFailure = "after_tool_failure" - runtimeHookPointSessionStart = "session_start" - runtimeHookPointSessionEnd = "session_end" - runtimeHookPointUserPromptSubmit = "user_prompt_submit" - runtimeHookPointPreCompact = "pre_compact" - runtimeHookPointPostCompact = "post_compact" - runtimeHookPointSubAgentStart = "subagent_start" - runtimeHookPointSubAgentStop = "subagent_stop" -) - const ( runtimeHookHandlerRequireFileExists = "require_file_exists" runtimeHookHandlerWarnOnToolCall = "warn_on_tool_call" @@ -78,6 +64,7 @@ type RuntimeHookItemConfig struct { Kind string `yaml:"kind,omitempty"` Mode string `yaml:"mode,omitempty"` Handler string `yaml:"handler,omitempty"` + Match map[string]any `yaml:"match,omitempty"` Priority int `yaml:"priority,omitempty"` TimeoutSec int `yaml:"timeout_sec,omitempty"` FailurePolicy string `yaml:"failure_policy,omitempty"` @@ -203,6 +190,12 @@ func (c 
RuntimeHookItemConfig) Clone() RuntimeHookItemConfig { if c.Enabled != nil { cloned.Enabled = boolPtr(*c.Enabled) } + if len(c.Match) > 0 { + cloned.Match = make(map[string]any, len(c.Match)) + for key, value := range c.Match { + cloned.Match[key] = cloneRuntimeHookParamValue(value) + } + } if len(c.Params) > 0 { cloned.Params = make(map[string]any, len(c.Params)) for key, value := range c.Params { @@ -246,25 +239,11 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { if strings.TrimSpace(c.ID) == "" { return fmt.Errorf("id is required") } - point := strings.TrimSpace(c.Point) - switch point { - case runtimeHookPointBeforeToolCall, - runtimeHookPointAfterToolResult, - runtimeHookPointBeforeCompletionDecision, - runtimeHookPointAcceptGate, - runtimeHookPointBeforePermissionDecision, - runtimeHookPointAfterToolFailure, - runtimeHookPointSessionStart, - runtimeHookPointSessionEnd, - runtimeHookPointUserPromptSubmit, - runtimeHookPointPreCompact, - runtimeHookPointPostCompact, - runtimeHookPointSubAgentStart, - runtimeHookPointSubAgentStop: - default: + point := hooks.HookPoint(strings.TrimSpace(c.Point)) + if _, ok := hooks.HookPointCapabilities(point); !ok { return fmt.Errorf("point %q is not supported", c.Point) } - if !runtimeHookPointUserAllowed(point) { + if !hooks.IsUserAllowed(point) { return fmt.Errorf("point %q does not allow user hooks", c.Point) } @@ -307,15 +286,25 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { default: return fmt.Errorf("handler %q is not supported", c.Handler) } - if handler == runtimeHookHandlerWarnOnToolCall && !hasWarnOnToolCallTargets(c.Params) { - return fmt.Errorf("handler %q requires params.tool_name or params.tool_names", c.Handler) + if handler == runtimeHookHandlerWarnOnToolCall && !hooks.HasHookMatcherConfig(c.Match) { + return fmt.Errorf("handler %q requires match", c.Handler) + } + if hooks.HasHookMatcherConfig(c.Match) { + if err := 
hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case runtimeHookKindCommand: if normalizedMode != runtimeHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", c.Mode) } - if strings.TrimSpace(readRuntimeHookParamString(c.Params, "command")) == "" { - return fmt.Errorf("kind command requires params.command") + if err := hooks.ValidateCommandParams(c.Params); err != nil { + return err + } + if hooks.HasHookMatcherConfig(c.Match) { + if err := hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } } case runtimeHookKindHTTP: if normalizedMode != runtimeHookModeObserve { @@ -324,6 +313,11 @@ func (c RuntimeHookItemConfig) Validate(defaultFailurePolicy string) error { if err := validateRuntimeHTTPObserveItem(c, policy); err != nil { return err } + if hooks.HasHookMatcherConfig(c.Match) { + if err := hooks.ValidateHookMatcher(point, c.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } } return nil } @@ -426,35 +420,6 @@ func cloneRuntimeHookParamValue(value any) any { } } -func hasWarnOnToolCallTargets(params map[string]any) bool { - if len(params) == 0 { - return false - } - toolNameRaw, hasToolName := params["tool_name"] - if hasToolName && strings.TrimSpace(fmt.Sprintf("%v", toolNameRaw)) != "" { - return true - } - toolNamesRaw, hasToolNames := params["tool_names"] - if !hasToolNames || toolNamesRaw == nil { - return false - } - switch typed := toolNamesRaw.(type) { - case []string: - for _, item := range typed { - if strings.TrimSpace(item) != "" { - return true - } - } - case []any: - for _, item := range typed { - if strings.TrimSpace(fmt.Sprintf("%v", item)) != "" { - return true - } - } - } - return false -} - // readRuntimeHookParamString reads a string value from runtime hook params in a compatibility-tolerant way. func readRuntimeHookParamString(params map[string]any, key string) string { if len(params) == 0 { @@ -471,12 +436,3 @@ func
readRuntimeHookParamString(params map[string]any, key string) string { return fmt.Sprintf("%v", typed) } } - -func runtimeHookPointUserAllowed(point string) bool { - switch strings.ToLower(strings.TrimSpace(point)) { - case runtimeHookPointBeforePermissionDecision, runtimeHookPointPreCompact, runtimeHookPointSubAgentStart: - return false - default: - return true - } -} diff --git a/internal/config/runtime_hooks_test.go b/internal/config/runtime_hooks_test.go index f039d8f0e..7c3cd646e 100644 --- a/internal/config/runtime_hooks_test.go +++ b/internal/config/runtime_hooks_test.go @@ -3,6 +3,8 @@ package config import ( "strings" "testing" + + "neo-code/internal/runtime/hooks" ) func TestRuntimeHooksConfigApplyDefaultsAndValidate(t *testing.T) { @@ -46,7 +48,7 @@ func TestRuntimeHooksConfigValidateUnsupportedFields(t *testing.T) { tests := []RuntimeHookItemConfig{ { ID: "bad-scope", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: "repo", Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, @@ -54,7 +56,7 @@ func TestRuntimeHooksConfigValidateUnsupportedFields(t *testing.T) { }, { ID: "bad-kind", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: "command", Mode: runtimeHookModeSync, @@ -62,7 +64,7 @@ func TestRuntimeHooksConfigValidateUnsupportedFields(t *testing.T) { }, { ID: "bad-mode", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: "async", @@ -70,7 +72,7 @@ func TestRuntimeHooksConfigValidateUnsupportedFields(t *testing.T) { }, { ID: "bad-handler", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, @@ -113,7 +115,7 @@ func TestRuntimeHooksConfigValidateRejectsExternalKindsWithP6LiteMessage(t *test cfg.Items = 
[]RuntimeHookItemConfig{ { ID: "external-kind", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: kind, Mode: runtimeHookModeSync, @@ -144,7 +146,59 @@ func TestRuntimeHooksConfigValidateAllowsCommand(t *testing.T) { Items: []RuntimeHookItemConfig{ { ID: "accept-command", - Point: runtimeHookPointAcceptGate, + Point: string(hooks.HookPointAcceptGate), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindCommand, + Mode: runtimeHookModeSync, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"command": []any{"echo", "ok"}}, + }, + }, + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestRuntimeHooksConfigValidateCommandShellMode(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "cmd-shell", + Point: string(hooks.HookPointAcceptGate), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindCommand, + Mode: runtimeHookModeSync, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"command": "echo ok", "shell": true}, + }, + }, + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestRuntimeHooksConfigValidateCommandStringWithoutShellRejected(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "cmd-no-shell", + Point: string(hooks.HookPointAcceptGate), Scope: runtimeHookScopeUser, Kind: runtimeHookKindCommand, Mode: runtimeHookModeSync, @@ -154,6 +208,32 @@ func TestRuntimeHooksConfigValidateAllowsCommand(t *testing.T) { }, }, } + 
if err := cfg.Validate(); err == nil { + t.Fatal("expected error for string command without shell=true") + } +} + +func TestRuntimeHooksConfigValidateCommandArgvMode(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "cmd-argv", + Point: string(hooks.HookPointAcceptGate), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindCommand, + Mode: runtimeHookModeSync, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"command": []string{"echo", "hello"}}, + }, + }, + } if err := cfg.Validate(); err != nil { t.Fatalf("Validate() error = %v", err) } @@ -170,7 +250,7 @@ func TestRuntimeHooksConfigValidateAllowsHTTPObserve(t *testing.T) { Items: []RuntimeHookItemConfig{ { ID: "observe-http", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindHTTP, Params: map[string]any{ @@ -193,7 +273,7 @@ func TestRuntimeHooksConfigValidateRejectsInvalidHTTPObserveConfig(t *testing.T) base := RuntimeHookItemConfig{ ID: "observe-http", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindHTTP, Mode: runtimeHookModeObserve, @@ -281,7 +361,7 @@ func TestRuntimeHooksConfigValidateRejectsDisallowedUserPoint(t *testing.T) { Items: []RuntimeHookItemConfig{ { ID: "deny-pre-compact", - Point: runtimeHookPointPreCompact, + Point: string(hooks.HookPointPreCompact), Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, @@ -308,11 +388,13 @@ func TestRuntimeHooksConfigItemDefaultsAndClone(t *testing.T) { Items: []RuntimeHookItemConfig{ { ID: "warn-bash", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Handler: 
runtimeHookHandlerWarnOnToolCall, - Params: map[string]any{ + Match: map[string]any{ "tool_name": "bash", - "tags": []any{"warn", "tool"}, + }, + Params: map[string]any{ + "tags": []any{"warn", "tool"}, }, }, }, @@ -368,7 +450,7 @@ func TestRuntimeHooksConfigValidateItemFailurePolicy(t *testing.T) { Items: []RuntimeHookItemConfig{ { ID: "require-readme", - Point: runtimeHookPointBeforeCompletionDecision, + Point: string(hooks.HookPointBeforeCompletionDecision), Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, @@ -394,7 +476,7 @@ func TestRuntimeHooksConfigValidateWarnOnToolCallRequiresTarget(t *testing.T) { Items: []RuntimeHookItemConfig{ { ID: "warn-missing-target", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, @@ -409,6 +491,65 @@ func TestRuntimeHooksConfigValidateWarnOnToolCallRequiresTarget(t *testing.T) { } } +func TestRuntimeHooksConfigValidateWarnOnToolCallAllowsMatchWithoutLegacyTargets(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "warn-with-match", + Point: string(hooks.HookPointBeforeToolCall), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindBuiltIn, + Mode: runtimeHookModeSync, + Handler: runtimeHookHandlerWarnOnToolCall, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Match: map[string]any{ + "tool_name": "bash", + }, + }, + }, + } + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() error = %v", err) + } +} + +func TestRuntimeHooksConfigValidateRejectsUnsupportedMatcherDimensionForPoint(t *testing.T) { + t.Parallel() + + cfg := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + 
DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + Items: []RuntimeHookItemConfig{ + { + ID: "session-start-match", + Point: string(hooks.HookPointSessionStart), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindBuiltIn, + Mode: runtimeHookModeSync, + Handler: runtimeHookHandlerAddContextNote, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"note": "observe"}, + Match: map[string]any{ + "tool_name": "bash", + }, + }, + }, + } + if err := cfg.Validate(); err == nil { + t.Fatal("expected unsupported matcher dimension to fail validation") + } +} + func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { t.Parallel() @@ -449,8 +590,8 @@ func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { DefaultTimeoutSec: 2, DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, Items: []RuntimeHookItemConfig{ - {ID: "dup", Point: runtimeHookPointBeforeToolCall, Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, Handler: runtimeHookHandlerWarnOnToolCall, TimeoutSec: 1, Params: map[string]any{"tool_name": "bash"}}, - {ID: " DUP ", Point: runtimeHookPointBeforeToolCall, Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, Handler: runtimeHookHandlerWarnOnToolCall, TimeoutSec: 1, Params: map[string]any{"tool_name": "bash"}}, + {ID: "dup", Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, Handler: runtimeHookHandlerWarnOnToolCall, TimeoutSec: 1, Params: map[string]any{"tool_name": "bash"}}, + {ID: " DUP ", Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, Handler: runtimeHookHandlerWarnOnToolCall, TimeoutSec: 1, Params: map[string]any{"tool_name": "bash"}}, }, } if err := cfg.Validate(); err == nil { @@ -465,7 +606,7 @@ func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { } item := 
RuntimeHookItemConfig{ ID: "x", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindBuiltIn, Mode: runtimeHookModeSync, @@ -525,20 +666,18 @@ func TestRuntimeHooksConfigEdgeBranches(t *testing.T) { t.Fatal("expected deep clone for nested map in slice") } - if hasWarnOnToolCallTargets(nil) { - t.Fatal("nil params should be false") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_name": "bash"}) { - t.Fatal("tool_name should pass") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_names": []string{"", "bash"}}) { - t.Fatal("tool_names []string should pass") - } - if !hasWarnOnToolCallTargets(map[string]any{"tool_names": []any{"", "bash"}}) { - t.Fatal("tool_names []any should pass") + matchCfg := RuntimeHookItemConfig{ + Match: map[string]any{ + "tool_name_regex": []any{`^bash$`}, + }, } - if hasWarnOnToolCallTargets(map[string]any{"tool_names": "bash"}) { - t.Fatal("tool_names scalar should fail") + clonedCfg := matchCfg.Clone() + clonedRegexes := clonedCfg.Match["tool_name_regex"].([]any) + clonedRegexes[0] = "^filesystem$" + clonedCfg.Match["tool_name_regex"] = clonedRegexes + originalRegexes := matchCfg.Match["tool_name_regex"].([]any) + if originalRegexes[0] == "^filesystem$" { + t.Fatal("expected match field to be deep-cloned") } }) } @@ -554,7 +693,7 @@ func TestRuntimeHTTPObserveValidationHelpers(t *testing.T) { } { item := RuntimeHookItemConfig{ ID: "observe-http", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindHTTP, Mode: runtimeHookModeObserve, @@ -580,7 +719,7 @@ func TestRuntimeHTTPObserveValidationHelpers(t *testing.T) { name: "invalid absolute url", item: RuntimeHookItemConfig{ ID: "observe-http", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindHTTP, Mode: 
runtimeHookModeObserve, @@ -593,7 +732,7 @@ func TestRuntimeHTTPObserveValidationHelpers(t *testing.T) { name: "headers must be map", item: RuntimeHookItemConfig{ ID: "observe-http", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindHTTP, Mode: runtimeHookModeObserve, @@ -607,7 +746,7 @@ func TestRuntimeHTTPObserveValidationHelpers(t *testing.T) { name: "empty header name", item: RuntimeHookItemConfig{ ID: "observe-http", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindHTTP, Mode: runtimeHookModeObserve, @@ -621,7 +760,7 @@ func TestRuntimeHTTPObserveValidationHelpers(t *testing.T) { name: "empty header value", item: RuntimeHookItemConfig{ ID: "observe-http", - Point: runtimeHookPointBeforeToolCall, + Point: string(hooks.HookPointBeforeToolCall), Scope: runtimeHookScopeUser, Kind: runtimeHookKindHTTP, Mode: runtimeHookModeObserve, @@ -663,17 +802,73 @@ func TestRuntimeHTTPObserveValidationHelpers(t *testing.T) { if got := readRuntimeHookParamString(map[string]any{"x": 123}, "x"); got != "123" { t.Fatalf("readRuntimeHookParamString(non-string) = %q", got) } - if !runtimeHookPointUserAllowed(runtimeHookPointBeforeToolCall) { + if !hooks.IsUserAllowed(hooks.HookPointBeforeToolCall) { t.Fatal("before_tool_call should allow user hooks") } - for _, point := range []string{ - runtimeHookPointBeforePermissionDecision, - runtimeHookPointPreCompact, - runtimeHookPointSubAgentStart, + for _, point := range []hooks.HookPoint{ + hooks.HookPointBeforePermissionDecision, + hooks.HookPointPreCompact, + hooks.HookPointSubAgentStart, } { - if runtimeHookPointUserAllowed(point) { + if hooks.IsUserAllowed(point) { t.Fatalf("%s should be rejected for user hooks", point) } } }) } + +// TestHookPointSingleSourceConsistency verifies that the config side and the runtime hooks package agree on hook point definitions. +// When adding a hook point, only the runtime hooks package needs to change; the config
side accepts it automatically. +func TestHookPointSingleSourceConsistency(t *testing.T) { + t.Parallel() + + // Every hook point exported by the runtime hooks package should be accepted by config. + allPoints := hooks.ListHookPoints() + if len(allPoints) == 0 { + t.Fatal("expected at least one hook point from runtime hooks package") + } + + base := RuntimeHooksConfig{ + Enabled: boolPtr(true), + UserHooksEnabled: boolPtr(true), + DefaultTimeoutSec: 2, + DefaultFailurePolicy: runtimeHookFailurePolicyWarnOnly, + } + + for _, point := range allPoints { + point := point + t.Run(string(point), func(t *testing.T) { + t.Parallel() + if !hooks.IsUserAllowed(point) { + // Skip points that do not allow user hooks; config validation rejects them. + return + } + cfg := base.Clone() + cfg.Items = []RuntimeHookItemConfig{ + { + ID: "test-" + string(point), + Point: string(point), + Scope: runtimeHookScopeUser, + Kind: runtimeHookKindBuiltIn, + Mode: runtimeHookModeSync, + Handler: runtimeHookHandlerAddContextNote, + TimeoutSec: 2, + FailurePolicy: runtimeHookFailurePolicyWarnOnly, + Params: map[string]any{"note": "consistency check"}, + }, + } + if err := cfg.Validate(); err != nil { + t.Fatalf("config rejected point %q: %v", point, err) + } + }) + } + + // Verify that accept_gate exists in the runtime hooks package and allows user hooks. + acceptGateCap, ok := hooks.HookPointCapabilities(hooks.HookPointAcceptGate) + if !ok { + t.Fatal("accept_gate not found in runtime hooks capabilities") + } + if !acceptGateCap.UserAllowed { + t.Fatal("accept_gate should allow user hooks") + } +} diff --git a/internal/context/builder.go b/internal/context/builder.go index 0ab6e2f08..131e8a226 100644 --- a/internal/context/builder.go +++ b/internal/context/builder.go @@ -2,22 +2,17 @@ package context import ( "context" - - providertypes "neo-code/internal/provider/types" - agentsession "neo-code/internal/session" ) // DefaultBuilder preserves the current runtime context-building behavior.
type DefaultBuilder struct { stablePromptSources []promptSectionSource dynamicPromptSources []promptSectionSource - promptSources []promptSectionSource trimPolicy messageTrimPolicy - microCompactCfg MicroCompactConfig } // newStablePromptSources returns the list of stable prompt sources, suitable for use as a cache prefix. -// Non-nil SectionSource values in extra are also appended to stable (e.g. the memo persistent-memory index). +// extra is appended to stable (e.g. the memo persistent-memory index). func newStablePromptSources(extra ...SectionSource) []promptSectionSource { sources := []promptSectionSource{ corePromptSource{}, @@ -44,52 +39,18 @@ func newDynamicPromptSources() []promptSectionSource { } } -// NewConfiguredBuilder builds a context builder from the aggregated config and an optional list of SectionSource values; it is the recommended unified constructor. -// When cfg.PinChecker is nil the default pin checker is used; nil elements in sources are skipped. -func NewConfiguredBuilder(cfg MicroCompactConfig, sources ...SectionSource) Builder { - if cfg.PinChecker == nil { - cfg.PinChecker = NewDefaultPinChecker() - } +// NewConfiguredBuilder builds a context builder from an optional list of SectionSource values; it is the recommended unified constructor. +func NewConfiguredBuilder(sources ...SectionSource) Builder { return &DefaultBuilder{ stablePromptSources: newStablePromptSources(sources...), dynamicPromptSources: newDynamicPromptSources(), trimPolicy: spanMessageTrimPolicy{}, - microCompactCfg: cfg, } } // NewBuilder returns the default context builder implementation.
func NewBuilder() Builder { - return NewConfiguredBuilder(MicroCompactConfig{}) -} - -// NewBuilderWithToolPolicies returns the default context builder with a tool micro compact policy source. -// -// Deprecated: use NewConfiguredBuilder instead. -func NewBuilderWithToolPolicies(policies MicroCompactPolicySource) Builder { - return NewConfiguredBuilder(MicroCompactConfig{Policies: policies}) -} - -// NewBuilderWithToolPoliciesAndSummarizers returns a context builder with tool policies and content summarizers. -// -// Deprecated: use NewConfiguredBuilder instead. -func NewBuilderWithToolPoliciesAndSummarizers(policies MicroCompactPolicySource, summarizers MicroCompactSummarizerSource) Builder { - return NewConfiguredBuilder(MicroCompactConfig{Policies: policies, Summarizers: summarizers}) -} - -// NewBuilderWithMemo returns a context builder with memo injection. -// When memoSource is nil, this is equivalent to NewBuilderWithToolPolicies. -// -// Deprecated: use NewConfiguredBuilder instead. -func NewBuilderWithMemo(policies MicroCompactPolicySource, memoSource SectionSource) Builder { - return NewConfiguredBuilder(MicroCompactConfig{Policies: policies}, memoSource) -} - -// NewBuilderWithMemoAndSummarizers returns a context builder with memo injection and content summarizers. -// -// Deprecated: use NewConfiguredBuilder instead. -func NewBuilderWithMemoAndSummarizers(policies MicroCompactPolicySource, summarizers MicroCompactSummarizerSource, memoSource SectionSource) Builder { - return NewConfiguredBuilder(MicroCompactConfig{Policies: policies, Summarizers: summarizers}, memoSource) + return NewConfiguredBuilder() } // collectPromptSections iterates over the promptSectionSource list and collects all sections. @@ -111,29 +72,17 @@ func (b *DefaultBuilder) Build(ctx context.Context, input BuildInput) (BuildResu return BuildResult{}, err } - stableSources := b.stablePromptSources - dynamicSources := b.dynamicPromptSources - - // Backward compatibility with the old construction style: if the new fields are unset but the old promptSources has content, fall back to the old single list. - if len(stableSources) == 0 && len(dynamicSources) == 0 && len(b.promptSources) > 0 { - stableSources = b.promptSources + stableSections, err := collectPromptSections(ctx, input, b.stablePromptSources) + if err != nil { + return
BuildResult{}, err } + stablePrompt := composeSystemPrompt(stableSections...) - var stablePrompt, dynamicPrompt string - if stableSources != nil { - stableSections, err := collectPromptSections(ctx, input, stableSources) - if err != nil { - return BuildResult{}, err - } - stablePrompt = composeSystemPrompt(stableSections...) - } - if dynamicSources != nil { - dynamicSections, err := collectPromptSections(ctx, input, dynamicSources) - if err != nil { - return BuildResult{}, err - } - dynamicPrompt = composeSystemPrompt(dynamicSections...) + dynamicSections, err := collectPromptSections(ctx, input, b.dynamicPromptSources) + if err != nil { + return BuildResult{}, err } + dynamicPrompt := composeSystemPrompt(dynamicSections...) systemPrompt := joinSystemPromptParts(stablePrompt, dynamicPrompt) @@ -141,46 +90,13 @@ func (b *DefaultBuilder) Build(ctx context.Context, input BuildInput) (BuildResu if trimPolicy == nil { trimPolicy = spanMessageTrimPolicy{} } - pinChecker := b.microCompactCfg.PinChecker - if pinChecker == nil { - pinChecker = NewDefaultPinChecker() - } return BuildResult{ SystemPrompt: systemPrompt, StableSystemPrompt: stablePrompt, DynamicSystemPrompt: dynamicPrompt, - Messages: applyReadTimeContextProjection( + Messages: projectReadTimeMessagesForModel( trimPolicy.Trim(input.Messages, input.Compact), - input.TaskState, - input.Compact, - b.microCompactCfg.Policies, - b.microCompactCfg.Summarizers, - pinChecker, ), }, nil } - -// applyReadTimeContextProjection applies a read-only context projection on the provider read path so that the original session messages are not rewritten. -func applyReadTimeContextProjection( - messages []providertypes.Message, - taskState agentsession.TaskState, - options CompactOptions, - policies MicroCompactPolicySource, - summarizers MicroCompactSummarizerSource, - pinChecker MicroCompactPinChecker, -) []providertypes.Message { - projectedMessages := cloneContextMessages(messages) - if options.DisableMicroCompact || !taskState.Established() { - return ProjectToolMessagesForModel(projectedMessages) - } - -
projectedMessages = microCompactMessagesWithPolicies( - projectedMessages, - policies, - options.MicroCompactRetainedToolSpans, - summarizers, - pinChecker, - ) - return ProjectToolMessagesForModel(projectedMessages) -} diff --git a/internal/context/builder_test.go b/internal/context/builder_test.go index 8f1f01c65..c76fe256f 100644 --- a/internal/context/builder_test.go +++ b/internal/context/builder_test.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "path/filepath" - "reflect" "strings" "testing" "time" @@ -15,7 +14,6 @@ import ( providertypes "neo-code/internal/provider/types" "neo-code/internal/rules" agentsession "neo-code/internal/session" - "neo-code/internal/tools" ) const maxRetainedMessageSpans = config.DefaultCompactReadTimeMaxMessageSpans @@ -249,29 +247,6 @@ func TestDefaultBuilderBuildIncludesTodosBeforeSystemState(t *testing.T) { } } -func TestNewBuilderWithMemoAndSummarizersIncludesMemoSection(t *testing.T) { - t.Parallel() - - builder := NewBuilderWithMemoAndSummarizers(nil, nil, stubPromptSectionSource{ - sections: []promptSection{ - NewPromptSection("memo", "remember this"), - }, - }) - - got, err := builder.Build(stdcontext.Background(), BuildInput{ - Messages: []providertypes.Message{ - {Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}, - }, - Metadata: testMetadata(t.TempDir()), - }) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if !strings.Contains(got.SystemPrompt, "## memo") { - t.Fatalf("expected memo section in prompt, got %q", got.SystemPrompt) - } -} - func TestDefaultBuilderBuildPlacesRulesBeforeMemo(t *testing.T) { t.Parallel() @@ -288,13 +263,14 @@ func TestDefaultBuilderBuildPlacesRulesBeforeMemo(t *testing.T) { } builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ + stablePromptSources: []promptSectionSource{ corePromptSource{}, newRulesPromptSource(rules.NewLoader(baseDir)), stubPromptSectionSource{sections: []promptSection{{Title: "Memo", Content: "remember this"}}}, + 
}, + dynamicPromptSources: []promptSectionSource{ &systemStateSource{}, }, - microCompactCfg: MicroCompactConfig{PinChecker: NewDefaultPinChecker()}, } got, err := builder.Build(stdcontext.Background(), BuildInput{ @@ -338,18 +314,15 @@ func TestDefaultBuilderBuildUsesSpanTrimPolicyWhenTrimPolicyIsUnset(t *testing.T } builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ + stablePromptSources: []promptSectionSource{ stubPromptSectionSource{sections: []promptSection{{Title: "Stub", Content: "body"}}}, }, - microCompactCfg: MicroCompactConfig{PinChecker: NewDefaultPinChecker()}, } got, err := builder.Build(stdcontext.Background(), BuildInput{ Messages: messages, TaskState: agentsession.TaskState{Goal: "keep implementing task"}, - Compact: CompactOptions{ - MicroCompactRetainedToolSpans: 2, - }, + Compact: CompactOptions{}, }) if err != nil { t.Fatalf("Build() error = %v", err) @@ -366,7 +339,7 @@ func TestDefaultBuilderBuildReturnsPromptSourceError(t *testing.T) { t.Parallel() builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ + stablePromptSources: []promptSectionSource{ stubPromptSectionSource{err: fmt.Errorf("source failed")}, }, } @@ -377,190 +350,10 @@ func TestDefaultBuilderBuildReturnsPromptSourceError(t *testing.T) { } } -func TestDefaultBuilderBuildAppliesMicroCompactAfterTrim(t *testing.T) { - t.Parallel() - - builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ - stubPromptSectionSource{sections: []promptSection{{Title: "Stub", Content: "body"}}}, - }, - microCompactCfg: MicroCompactConfig{PinChecker: NewDefaultPinChecker()}, - } - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: 
[]providertypes.ContentPart{providertypes.NewTextPart("old read result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - - got, err := builder.Build(stdcontext.Background(), BuildInput{ - Messages: messages, - TaskState: agentsession.TaskState{Goal: "keep implementing task"}, - Compact: CompactOptions{ - MicroCompactRetainedToolSpans: 2, - }, - }) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if len(got.Messages) != len(messages) { - t.Fatalf("expected builder output to keep message count, got %d want %d", len(got.Messages), len(messages)) - } - if !strings.Contains(renderDisplayParts(got.Messages[2].Parts), "[summary] filesystem_read_file") { - t.Fatalf("expected builder output to summarize older tool result, got %q", renderDisplayParts(got.Messages[2].Parts)) - } - if renderDisplayParts(got.Messages[4].Parts) != "recent bash result" { - t.Fatalf("expected recent tool result to stay visible, got %q", renderDisplayParts(got.Messages[4].Parts)) - } - if renderDisplayParts(got.Messages[6].Parts) != "latest webfetch result" { - t.Fatalf("expected latest tool result to stay visible, got %q", renderDisplayParts(got.Messages[6].Parts)) - } -} - -func 
TestDefaultBuilderBuildDefaultsPinCheckerForLiteralBuilder(t *testing.T) { +func TestNewBuilder(t *testing.T) { t.Parallel() - builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ - stubPromptSectionSource{sections: []promptSection{{Title: "Stub", Content: "body"}}}, - }, - microCompactCfg: MicroCompactConfig{PinChecker: NewDefaultPinChecker()}, - } - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_write_file", Arguments: `{"path":"README.md"}`}, - }, - }, - { - Role: providertypes.RoleTool, - ToolCallID: "call-1", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("README content")}, - ToolMetadata: map[string]string{ - "path": "/project/README.md", - }, - }, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - - got, err := builder.Build(stdcontext.Background(), BuildInput{ - Messages: messages, - TaskState: agentsession.TaskState{Goal: "keep implementing task"}, - Compact: CompactOptions{ - MicroCompactRetainedToolSpans: 1, - }, - }) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - projectedText := renderDisplayParts(got.Messages[2].Parts) - if projectedText == microCompactClearedMessage { - t.Fatalf("expected pinned README result to avoid cleared placeholder, got %q", projectedText) - } - if !strings.Contains(projectedText, "README 
content") { - t.Fatalf("expected pinned README result to retain content, got %q", projectedText) - } -} - -func TestDefaultBuilderBuildRespectsExplicitPinCheckerOverride(t *testing.T) { - t.Parallel() - - builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ - stubPromptSectionSource{sections: []promptSection{{Title: "Stub", Content: "body"}}}, - }, - microCompactCfg: MicroCompactConfig{PinChecker: noopPinChecker{}}, - } - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_write_file", Arguments: `{"path":"README.md"}`}, - }, - }, - { - Role: providertypes.RoleTool, - ToolCallID: "call-1", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("README content")}, - ToolMetadata: map[string]string{ - "path": "/project/README.md", - }, - }, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - - got, err := builder.Build(stdcontext.Background(), BuildInput{ - Messages: messages, - TaskState: agentsession.TaskState{Goal: "keep implementing task"}, - Compact: CompactOptions{ - MicroCompactRetainedToolSpans: 1, - }, - }) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if !strings.Contains(renderDisplayParts(got.Messages[2].Parts), "[summary] filesystem_write_file") { - t.Fatalf("expected explicit noop pin checker to allow compaction into summary, got %q", 
renderDisplayParts(got.Messages[2].Parts)) - } -} - -type noopPinChecker struct{} - -func (noopPinChecker) ShouldPin(string, map[string]string) bool { return false } - -func TestNewBuilderWithToolPoliciesAndSummarizers(t *testing.T) { - t.Parallel() - - builder := NewBuilderWithToolPoliciesAndSummarizers( - nil, - stubMicroCompactSummarizerSource{ - "filesystem_read_file": func(content string, metadata map[string]string, isError bool) string { - return "[summary] read_file" - }, - }, - ) + builder := NewBuilder() messages := []providertypes.Message{ {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, @@ -591,204 +384,21 @@ func TestNewBuilderWithToolPoliciesAndSummarizers(t *testing.T) { got, err := builder.Build(stdcontext.Background(), BuildInput{ Messages: messages, TaskState: agentsession.TaskState{Goal: "keep implementing task"}, - Compact: CompactOptions{ - MicroCompactRetainedToolSpans: 2, - }, - Metadata: testMetadata(t.TempDir()), + Compact: CompactOptions{}, + Metadata: testMetadata(t.TempDir()), }) if err != nil { t.Fatalf("Build() error = %v", err) } - const summarizedMessageIndex = 2 - if renderDisplayParts(got.Messages[summarizedMessageIndex].Parts) != "[summary] read_file" { + const olderReadIndex = 2 + if renderDisplayParts(got.Messages[olderReadIndex].Parts) != "old read result" { t.Fatalf( - "expected summarized older read result, got %q", - renderDisplayParts(got.Messages[summarizedMessageIndex].Parts), + "expected older read result content, got %q", + renderDisplayParts(got.Messages[olderReadIndex].Parts), ) } } -func TestDefaultBuilderBuildSkipsMicroCompactWithoutEstablishedTaskState(t *testing.T) { - t.Parallel() - - builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ - stubPromptSectionSource{sections: []promptSection{{Title: "Stub", Content: "body"}}}, - }, - microCompactCfg: MicroCompactConfig{PinChecker: NewDefaultPinChecker()}, - } - - messages := 
[]providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old read result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - } - - got, err := builder.Build(stdcontext.Background(), BuildInput{Messages: messages}) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if renderDisplayParts(got.Messages[2].Parts) != "old read result" { - t.Fatalf("expected old tool result to remain visible without task state, got %q", renderDisplayParts(got.Messages[2].Parts)) - } -} - -func TestDefaultBuilderBuildSkipsMicroCompactWhenDisabled(t *testing.T) { - t.Parallel() - - builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ - stubPromptSectionSource{sections: []promptSection{{Title: "Stub", Content: "body"}}}, - }, - microCompactCfg: MicroCompactConfig{PinChecker: NewDefaultPinChecker()}, - } - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old read result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: 
providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - - got, err := builder.Build(stdcontext.Background(), BuildInput{ - Messages: messages, - Compact: CompactOptions{ - DisableMicroCompact: true, - }, - }) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if !reflect.DeepEqual(got.Messages, messages) { - t.Fatalf("expected messages to remain unchanged when micro compact is disabled, got %+v", got.Messages) - } - if &got.Messages[2] == &messages[2] { - t.Fatalf("expected disabled path to still clone message slice") - } -} - -func TestDefaultBuilderBuildHonorsToolMicroCompactPolicies(t *testing.T) { - t.Parallel() - - builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ - stubPromptSectionSource{sections: []promptSection{{Title: "Stub", Content: "body"}}}, - }, - microCompactCfg: MicroCompactConfig{ - Policies: stubMicroCompactPolicySource{"custom_tool": tools.MicroCompactPolicyPreserveHistory}, - PinChecker: NewDefaultPinChecker(), - }, - } - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "custom_tool", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: 
[]providertypes.ContentPart{providertypes.NewTextPart("old custom result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - got, err := builder.Build(stdcontext.Background(), BuildInput{Messages: messages}) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if renderDisplayParts(got.Messages[2].Parts) != "old custom result" { - t.Fatalf("expected preserved tool result to remain, got %q", renderDisplayParts(got.Messages[2].Parts)) - } -} - -func TestNewBuilderWithToolPoliciesUsesProvidedPolicySource(t *testing.T) { - t.Parallel() - - builder := NewBuilderWithToolPolicies(stubMicroCompactPolicySource{ - "custom_tool": tools.MicroCompactPolicyPreserveHistory, - }) - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "custom_tool", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old custom result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: 
[]providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - got, err := builder.Build(stdcontext.Background(), BuildInput{Messages: messages}) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if renderDisplayParts(got.Messages[2].Parts) != "old custom result" { - t.Fatalf("expected preserved tool result to remain, got %q", renderDisplayParts(got.Messages[2].Parts)) - } -} - func TestTrimMessagesPreservesToolPairs(t *testing.T) { t.Parallel() @@ -996,14 +606,11 @@ func TestTrimMessagesBoundaries(t *testing.T) { } } -func TestNewBuilderWithMemo(t *testing.T) { +func TestNewConfiguredBuilder(t *testing.T) { t.Parallel() - t.Run("with memo source injects memo section", func(t *testing.T) { - memoSource := stubPromptSectionSource{ - sections: []promptSection{{Title: "Memo", Content: "- [user] test entry"}}, - } - builder := NewBuilderWithMemo(stubMicroCompactPolicySource{}, memoSource) + t.Run("no extra sources builds non-empty prompt", func(t *testing.T) { + builder := NewConfiguredBuilder() input := BuildInput{ Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, Metadata: testMetadata(t.TempDir()), @@ -1012,16 +619,16 @@ func TestNewBuilderWithMemo(t *testing.T) { if err != nil { t.Fatalf("Build() error = %v", err) } - if !strings.Contains(result.SystemPrompt, "## Memo") { - t.Errorf("expected Memo section in system prompt") - } - if !strings.Contains(result.SystemPrompt, "test entry") { - t.Errorf("expected memo content in system prompt") + if
result.SystemPrompt == "" { + t.Fatalf("expected non-empty system prompt") } }) - t.Run("nil memo source skips memo section", func(t *testing.T) { - builder := NewBuilderWithMemo(stubMicroCompactPolicySource{}, nil) + t.Run("with extra section sources", func(t *testing.T) { + extraSource := stubPromptSectionSource{ + sections: []promptSection{{Title: "Custom", Content: "custom section body"}}, + } + builder := NewConfiguredBuilder(extraSource) input := BuildInput{ Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, Metadata: testMetadata(t.TempDir()), @@ -1030,150 +637,54 @@ func TestNewBuilderWithMemo(t *testing.T) { if err != nil { t.Fatalf("Build() error = %v", err) } - if strings.Contains(result.SystemPrompt, "## Memo") { - t.Error("nil memo source should not inject Memo section") + if !strings.Contains(result.SystemPrompt, "## Custom") { + t.Errorf("expected Custom section in system prompt") + } + if !strings.Contains(result.SystemPrompt, "custom section body") { + t.Errorf("expected custom section content in system prompt") } }) -} - -func TestNewConfiguredBuilder(t *testing.T) { - t.Parallel() - t.Run("empty config defaults pin checker", func(t *testing.T) { - builder := NewConfiguredBuilder(MicroCompactConfig{}) + t.Run("nil extra source is safely ignored", func(t *testing.T) { + t.Parallel() + builder := NewConfiguredBuilder(nil) input := BuildInput{ Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, Metadata: testMetadata(t.TempDir()), } result, err := builder.Build(stdcontext.Background(), input) if err != nil { - t.Fatalf("Build() error = %v", err) + t.Fatalf("Build() with nil source error = %v", err) } if result.SystemPrompt == "" { - t.Fatalf("expected non-empty system prompt") + t.Fatal("expected non-empty system prompt even with nil extra source") } }) - t.Run("with policies and summarizers", func(t 
*testing.T) { - cfg := MicroCompactConfig{ - Policies: stubMicroCompactPolicySource{}, - Summarizers: stubMicroCompactSummarizerSource{ - "filesystem_read_file": func(content string, metadata map[string]string, isError bool) string { - return "[summary] read_file" - }, - }, - } - builder := NewConfiguredBuilder(cfg) - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old read result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - got, err := builder.Build(stdcontext.Background(), BuildInput{ - Messages: messages, - TaskState: agentsession.TaskState{Goal: "keep implementing task"}, - Compact: CompactOptions{ - MicroCompactRetainedToolSpans: 2, - }, - Metadata: testMetadata(t.TempDir()), - }) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if renderDisplayParts(got.Messages[2].Parts) != "[summary] read_file" { - t.Fatalf("expected summarized older read result, got %q", renderDisplayParts(got.Messages[2].Parts)) - } - }) - - t.Run("with custom pin 
checker", func(t *testing.T) { - cfg := MicroCompactConfig{ - PinChecker: noopPinChecker{}, - } - builder := NewConfiguredBuilder(cfg) - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_write_file", Arguments: `{"path":"README.md"}`}, - }, - }, - { - Role: providertypes.RoleTool, - ToolCallID: "call-1", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("README content")}, - ToolMetadata: map[string]string{"path": "/project/README.md"}, - }, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - got, err := builder.Build(stdcontext.Background(), BuildInput{ - Messages: messages, - TaskState: agentsession.TaskState{Goal: "keep implementing task"}, - Compact: CompactOptions{ - MicroCompactRetainedToolSpans: 1, - }, - }) - if err != nil { - t.Fatalf("Build() error = %v", err) - } - if !strings.Contains(renderDisplayParts(got.Messages[2].Parts), "[summary] filesystem_write_file") { - t.Fatalf("expected noop pin checker to allow compaction into summary, got %q", renderDisplayParts(got.Messages[2].Parts)) - } - }) - - t.Run("with extra section sources", func(t *testing.T) { - extraSource := stubPromptSectionSource{ - sections: []promptSection{{Title: "Custom", Content: "custom section body"}}, - } - builder := NewConfiguredBuilder(MicroCompactConfig{}, extraSource) + t.Run("mixed 
nil and valid extra sources", func(t *testing.T) { + t.Parallel() + builder := NewConfiguredBuilder(nil, stubPromptSectionSource{ + sections: []promptSection{{Title: "Valid", Content: "valid section"}}, + }, nil) input := BuildInput{ Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, Metadata: testMetadata(t.TempDir()), } result, err := builder.Build(stdcontext.Background(), input) if err != nil { - t.Fatalf("Build() error = %v", err) - } - if !strings.Contains(result.SystemPrompt, "## Custom") { - t.Errorf("expected Custom section in system prompt") + t.Fatalf("Build() with mixed nil/valid sources error = %v", err) } - if !strings.Contains(result.SystemPrompt, "custom section body") { - t.Errorf("expected custom section content in system prompt") + if !strings.Contains(result.SystemPrompt, "## Valid") { + t.Fatal("expected valid section to be present while nil sources are ignored") } }) - t.Run("nil section sources are skipped", func(t *testing.T) { - builder := NewConfiguredBuilder(MicroCompactConfig{}, nil, stubPromptSectionSource{ + t.Run("multiple extra section sources are appended", func(t *testing.T) { + builder := NewConfiguredBuilder(stubPromptSectionSource{ + sections: []promptSection{{Title: "First", Content: "first body"}}, + }, stubPromptSectionSource{ sections: []promptSection{{Title: "Extra", Content: "extra body"}}, - }, nil) + }) input := BuildInput{ Messages: []providertypes.Message{{Role: "user", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}}, Metadata: testMetadata(t.TempDir()), @@ -1185,6 +696,9 @@ func TestNewConfiguredBuilder(t *testing.T) { if !strings.Contains(result.SystemPrompt, "## Extra") { t.Errorf("expected Extra section in system prompt") } + if !strings.Contains(result.SystemPrompt, "## First") { + t.Errorf("expected First section in system prompt") + } }) } @@ -1259,6 +773,66 @@ func 
TestDefaultBuilderBuildProjectsMetadataOnlyToolResult(t *testing.T) { } } +func TestDefaultBuilderBuildBoundsProjectedToolContentWithoutMutatingInput(t *testing.T) { + t.Parallel() + + head := strings.Repeat("A", recentWindowToolContentHeadChars) + middle := strings.Repeat("M", 64) + tail := strings.Repeat("B", recentWindowToolContentTailChars-len("TAIL-MARKER")) + "TAIL-MARKER" + rawBody := head + middle + tail + originalMetadata := map[string]string{"tool_name": "bash", "workdir": "D:/project"} + messages := []providertypes.Message{ + {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("run command")}}, + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "bash", Arguments: `{}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Parts: []providertypes.ContentPart{providertypes.NewTextPart(rawBody)}, + ToolMetadata: originalMetadata, + }, + } + + result, err := NewBuilder().Build(stdcontext.Background(), BuildInput{ + Messages: messages, + Metadata: testMetadata(t.TempDir()), + }) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + projectedText := renderDisplayParts(result.Messages[2].Parts) + if !strings.Contains(projectedText, "tool result") || + !strings.Contains(projectedText, "tool: bash") || + !strings.Contains(projectedText, "content_excerpt:") { + t.Fatalf("expected structured excerpted tool result, got %q", projectedText) + } + if strings.Contains(projectedText, "\ncontent:\n") { + t.Fatalf("expected full content marker to be removed, got %q", projectedText) + } + if strings.Contains(projectedText, middle) { + t.Fatalf("expected middle payload to be truncated, got %q", projectedText) + } + if !strings.Contains(projectedText, truncatedExcerptMarker) || + !strings.Contains(projectedText, "TAIL-MARKER") || + !strings.Contains(projectedText, contentTruncatedForModelContext) { + t.Fatalf("expected head/tail truncation markers, got %q", 
projectedText) + } + if result.Messages[2].ToolMetadata != nil { + t.Fatalf("expected projected metadata to be cleared, got %#v", result.Messages[2].ToolMetadata) + } + if renderDisplayParts(messages[2].Parts) != rawBody { + t.Fatalf("expected source tool content unchanged, got %q", renderDisplayParts(messages[2].Parts)) + } + if messages[2].ToolMetadata["tool_name"] != "bash" || messages[2].ToolMetadata["workdir"] != "D:/project" { + t.Fatalf("expected source tool metadata unchanged, got %#v", messages[2].ToolMetadata) + } +} + func TestDefaultBuilderBuildReturnsStableAndDynamicPrompts(t *testing.T) { t.Parallel() @@ -1338,7 +912,7 @@ func TestDefaultBuilderBuildTodoChangeDoesNotChangeStablePrompt(t *testing.T) { func TestDefaultBuilderBuildMemoIsStable(t *testing.T) { t.Parallel() - builder := NewConfiguredBuilder(MicroCompactConfig{}, stubPromptSectionSource{ + builder := NewConfiguredBuilder(stubPromptSectionSource{ sections: []promptSection{ NewPromptSection("memo", "remember this"), }, @@ -1362,24 +936,30 @@ func TestDefaultBuilderBuildMemoIsStable(t *testing.T) { } } -func TestDefaultBuilderBuildStableAndDynamicPreservesBackwardCompat(t *testing.T) { +func TestDefaultBuilderBuildStableAndDynamicFields(t *testing.T) { t.Parallel() builder := &DefaultBuilder{ - promptSources: []promptSectionSource{ - stubPromptSectionSource{sections: []promptSection{{Title: "Old", Content: "old style"}}}, + stablePromptSources: []promptSectionSource{ + stubPromptSectionSource{sections: []promptSection{{Title: "Stable", Content: "stable style"}}}, + }, + dynamicPromptSources: []promptSectionSource{ + stubPromptSectionSource{sections: []promptSection{{Title: "Dynamic", Content: "dynamic style"}}}, }, - microCompactCfg: MicroCompactConfig{PinChecker: NewDefaultPinChecker()}, + trimPolicy: spanMessageTrimPolicy{}, } result, err := builder.Build(stdcontext.Background(), BuildInput{}) if err != nil { t.Fatalf("Build() error = %v", err) } - if !strings.Contains(result.SystemPrompt, 
"old style") { - t.Fatalf("expected old style content in system prompt, got %q", result.SystemPrompt) + if !strings.Contains(result.SystemPrompt, "stable style") || !strings.Contains(result.SystemPrompt, "dynamic style") { + t.Fatalf("expected stable and dynamic content in system prompt, got %q", result.SystemPrompt) + } + if !strings.Contains(result.StableSystemPrompt, "stable style") { + t.Fatalf("expected stable content in StableSystemPrompt, got %q", result.StableSystemPrompt) } - if !strings.Contains(result.StableSystemPrompt, "old style") { - t.Fatalf("expected old style content in StableSystemPrompt, got %q", result.StableSystemPrompt) + if !strings.Contains(result.DynamicSystemPrompt, "dynamic style") { + t.Fatalf("expected dynamic content in DynamicSystemPrompt, got %q", result.DynamicSystemPrompt) } } diff --git a/internal/context/microcompact.go b/internal/context/microcompact.go deleted file mode 100644 index 050a543ee..000000000 --- a/internal/context/microcompact.go +++ /dev/null @@ -1,331 +0,0 @@ -package context - -import ( - "strconv" - "strings" - "unicode/utf8" - - "neo-code/internal/config" - "neo-code/internal/context/internalcompact" - providertypes "neo-code/internal/provider/types" - "neo-code/internal/tools" -) - -const ( - // microCompactClearedMessage 是旧工具结果被读时微压缩后的占位符文本。 - microCompactClearedMessage = "[Old tool result content cleared]" - // microCompactSummaryMaxRunes 是摘要回灌到上下文前允许的最大 rune 数量。 - microCompactSummaryMaxRunes = 200 -) - -// microCompactMessages 对裁剪后的消息做只读投影式微压缩,优先摘要旧工具结果,失败时回退清理占位。 -func microCompactMessages(messages []providertypes.Message) []providertypes.Message { - return microCompactMessagesWithPolicies(messages, nil, 0, nil, nil) -} - -// microCompactMessagesWithPolicies 按工具策略对裁剪后的消息做只读投影式微压缩。 -// 仅对需要压缩的工具消息做深拷贝,其余消息共享原始引用以减少内存分配。 -func microCompactMessagesWithPolicies(messages []providertypes.Message, policies MicroCompactPolicySource, retainedToolSpans int, summarizers MicroCompactSummarizerSource, pinChecker 
MicroCompactPinChecker) []providertypes.Message { - if retainedToolSpans <= 0 { - retainedToolSpans = config.DefaultMicroCompactRetainedToolSpans - } - - if len(messages) == 0 { - return nil - } - - spans := internalcompact.BuildMessageSpans(messages) - protectedStart, hasProtectedTail := internalcompact.ProtectedTailStart(spans) - retainedCompactableSpans := 0 - - modifiedIndices := make(map[int]struct{}) - var pendingCompactions []compactionPending - - for spanIndex := len(spans) - 1; spanIndex >= 0; spanIndex-- { - span := spans[spanIndex] - if hasProtectedTail && span.Start >= protectedStart { - continue - } - if !isToolCallSpan(messages, span) { - continue - } - - compactableIDs, toolNames := compactableToolCallIDs(messages[span.Start].ToolCalls, policies) - if len(compactableIDs) == 0 { - continue - } - if retainedCompactableSpans < retainedToolSpans { - if hasCompactableToolMessage(messages, span, compactableIDs, toolNames, pinChecker) { - retainedCompactableSpans++ - } - continue - } - - compactableContents := compactableToolMessageContents(messages, span, compactableIDs, toolNames, pinChecker) - if len(compactableContents) == 0 { - continue - } - - for messageIndex, content := range compactableContents { - modifiedIndices[messageIndex] = struct{}{} - pendingCompactions = append(pendingCompactions, compactionPending{ - index: messageIndex, - content: content, - toolNames: toolNames, - }) - } - } - - if len(modifiedIndices) == 0 { - return append([]providertypes.Message(nil), messages...) 
- } - - cloned := make([]providertypes.Message, len(messages)) - for i, msg := range messages { - if _, needsClone := modifiedIndices[i]; needsClone { - cloned[i] = cloneSingleMessage(msg) - } else { - cloned[i] = msg - } - } - - for _, pending := range pendingCompactions { - summary := summarizeOrClear(cloned[pending.index], pending.content, pending.toolNames, summarizers) - cloned[pending.index].Parts = []providertypes.ContentPart{providertypes.NewTextPart(summary)} - } - - return cloned -} - -// compactionPending 记录待压缩的消息索引和所需上下文。 -type compactionPending struct { - index int - content string - toolNames map[string]string -} - -// cloneContextMessages 深拷贝消息切片,避免读时投影污染 runtime 持有的原始会话消息。 -func cloneContextMessages(messages []providertypes.Message) []providertypes.Message { - if len(messages) == 0 { - return nil - } - - cloned := make([]providertypes.Message, 0, len(messages)) - for _, message := range messages { - cloned = append(cloned, cloneSingleMessage(message)) - } - return cloned -} - -// cloneSingleMessage 深拷贝单条消息,隔离 ToolCalls 和 ToolMetadata 的底层引用。 -func cloneSingleMessage(msg providertypes.Message) providertypes.Message { - next := msg - next.ToolCalls = append([]providertypes.ToolCall(nil), msg.ToolCalls...) 
- if len(msg.ToolMetadata) > 0 { - next.ToolMetadata = make(map[string]string, len(msg.ToolMetadata)) - for key, value := range msg.ToolMetadata { - next.ToolMetadata[key] = value - } - } - return next -} - -// isToolCallSpan 判断当前 span 是否是由 assistant tool call 起始的原子工具块。 -func isToolCallSpan(messages []providertypes.Message, span internalcompact.MessageSpan) bool { - if span.Start < 0 || span.Start >= len(messages) { - return false - } - message := messages[span.Start] - return message.Role == providertypes.RoleAssistant && len(message.ToolCalls) > 0 -} - -// compactableToolCallIDs 返回 assistant tool call 中可参与微压缩的调用 ID 集合及对应的工具名映射。 -func compactableToolCallIDs(calls []providertypes.ToolCall, policies MicroCompactPolicySource) (map[string]struct{}, map[string]string) { - if len(calls) == 0 { - return nil, nil - } - - ids := make(map[string]struct{}, len(calls)) - toolNames := make(map[string]string, len(calls)) - for _, call := range calls { - toolName := strings.TrimSpace(call.Name) - if !toolParticipatesInMicroCompact(toolName, policies) { - continue - } - callID := strings.TrimSpace(call.ID) - if callID == "" { - continue - } - ids[callID] = struct{}{} - toolNames[callID] = toolName - } - if len(ids) == 0 { - return nil, nil - } - return ids, toolNames -} - -// toolParticipatesInMicroCompact 判断工具是否应参与 micro compact;未知工具默认视为可压缩。 -func toolParticipatesInMicroCompact(toolName string, policies MicroCompactPolicySource) bool { - if policies == nil { - return true - } - return policies.MicroCompactPolicy(toolName) != tools.MicroCompactPolicyPreserveHistory -} - -// compactableToolMessageContents 收集工具块中可压缩消息的渲染内容,跳过被钉住的结果。 -func compactableToolMessageContents(messages []providertypes.Message, span internalcompact.MessageSpan, compactableIDs map[string]struct{}, toolNames map[string]string, pinChecker MicroCompactPinChecker) map[int]string { - var contents map[int]string - for messageIndex := span.Start + 1; messageIndex < span.End; messageIndex++ { - content, ok := 
isCompactableToolMessage(messages[messageIndex], compactableIDs, toolNames, pinChecker) - if !ok { - continue - } - if contents == nil { - contents = make(map[int]string) - } - contents[messageIndex] = content - } - return contents -} - -// hasCompactableToolMessage 判断工具块中是否存在至少一条可压缩且未被钉住的工具消息。 -func hasCompactableToolMessage(messages []providertypes.Message, span internalcompact.MessageSpan, compactableIDs map[string]struct{}, toolNames map[string]string, pinChecker MicroCompactPinChecker) bool { - for messageIndex := span.Start + 1; messageIndex < span.End; messageIndex++ { - if _, ok := isCompactableToolMessage(messages[messageIndex], compactableIDs, toolNames, pinChecker); ok { - return true - } - } - return false -} - -// isCompactableToolMessage 判断工具消息是否可压缩(非保留策略且未被钉住),返回渲染内容和是否可压缩。 -func isCompactableToolMessage(message providertypes.Message, compactableIDs map[string]struct{}, toolNames map[string]string, pinChecker MicroCompactPinChecker) (string, bool) { - content, ok := compactableToolMessageContent(message, compactableIDs) - if !ok { - return "", false - } - callID := strings.TrimSpace(message.ToolCallID) - toolName := toolNameFromCallID(callID, toolNames) - if isPinnedToolMessage(toolName, message.ToolMetadata, pinChecker) { - return "", false - } - return content, true -} - -// compactableToolMessageContent 判断 tool 消息是否可压缩,并返回渲染后的内容文本。 -func compactableToolMessageContent(message providertypes.Message, compactableIDs map[string]struct{}) (string, bool) { - if message.Role != providertypes.RoleTool || message.IsError { - return "", false - } - callID := strings.TrimSpace(message.ToolCallID) - if _, ok := compactableIDs[callID]; !ok { - return "", false - } - - content := strings.TrimSpace(renderDisplayParts(message.Parts)) - if content == "" || content == microCompactClearedMessage { - return "", false - } - return content, true -} - -// summarizeOrClear 为单条可压缩工具消息生成摘要或回退到默认清除占位。 -func summarizeOrClear( - message providertypes.Message, - content string, 
- toolNames map[string]string, - summarizers MicroCompactSummarizerSource, -) string { - callID := strings.TrimSpace(message.ToolCallID) - toolName, ok := toolNames[callID] - if !ok { - return microCompactClearedMessage - } - - if summarizers != nil { - summarizer := summarizers.MicroCompactSummarizer(toolName) - if summarizer != nil { - summary := summarizer(content, message.ToolMetadata, message.IsError) - if summary != "" { - summary = sanitizeMicroCompactSummary(summary) - if summary != "" { - return summary - } - } - } - } - - summary := sanitizeMicroCompactSummary(fallbackSummary(toolName, content)) - if summary == "" { - return microCompactClearedMessage - } - return summary -} - -// fallbackSummary 为缺少专用摘要器的工具生成最小可读摘要,避免静默清空历史。 -func fallbackSummary(toolName string, content string) string { - trimmedName := strings.TrimSpace(toolName) - if trimmedName == "" { - return "" - } - - parts := []string{ - "[summary]", - trimmedName, - "lines=" + strconv.Itoa(stableLineCount(content)), - "chars=" + strconv.Itoa(utf8.RuneCountInString(content)), - } - return strings.Join(parts, " ") -} - -// stableLineCount 统计文本行数;空文本返回 0,末尾换行不会产生额外空行计数。 -func stableLineCount(text string) int { - if text == "" { - return 0 - } - count := strings.Count(text, "\n") + 1 - if strings.HasSuffix(text, "\n") { - count-- - } - if count < 0 { - return 0 - } - return count -} - -// sanitizeMicroCompactSummary 对 summarizer 输出做最终净化与限长,避免把不安全文本直接回灌上下文。 -func sanitizeMicroCompactSummary(summary string) string { - trimmed := strings.TrimSpace(summary) - if trimmed == "" { - return "" - } - - var b strings.Builder - b.Grow(len(trimmed)) - for _, r := range trimmed { - if r < 32 || r == 127 { - continue - } - b.WriteRune(r) - } - - clean := strings.TrimSpace(b.String()) - if clean == "" { - return "" - } - return truncateSummaryRunes(clean, microCompactSummaryMaxRunes) -} - -// truncateSummaryRunes 按 rune 数量截断摘要,超限时追加 "..."。 -func truncateSummaryRunes(summary string, maxRunes int) string { - if 
maxRunes <= 0 || summary == "" { - return summary - } - - runes := []rune(summary) - if len(runes) <= maxRunes { - return summary - } - return string(runes[:maxRunes]) + "..." -} diff --git a/internal/context/microcompact_summarizer_test.go b/internal/context/microcompact_summarizer_test.go deleted file mode 100644 index e1b010e8c..000000000 --- a/internal/context/microcompact_summarizer_test.go +++ /dev/null @@ -1,416 +0,0 @@ -package context - -import ( - "strings" - "testing" - "unicode/utf8" - - "neo-code/internal/context/internalcompact" - providertypes "neo-code/internal/provider/types" - "neo-code/internal/tools" -) - -// stubMicroCompactSummarizerSource 实现 MicroCompactSummarizerSource,用于测试。 -type stubMicroCompactSummarizerSource map[string]tools.ContentSummarizer - -func (s stubMicroCompactSummarizerSource) MicroCompactSummarizer(name string) tools.ContentSummarizer { - return s[name] -} - -// TestMicroCompactWithSummarizerProducesSummary 验证注册 summarizer 的工具生成摘要而非清除占位。 -func TestMicroCompactWithSummarizerProducesSummary(t *testing.T) { - t.Parallel() - - bashSummarizer := func(content string, metadata map[string]string, isError bool) string { - return "[summary] bash: " + content - } - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ 
- {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - - got := microCompactMessagesWithPolicies( - messages, - stubMicroCompactPolicySource{}, - 2, - stubMicroCompactSummarizerSource{"bash": bashSummarizer}, - nil, - ) - - if renderDisplayParts(got[2].Parts) == microCompactClearedMessage { - t.Fatalf("expected summarized content for old bash result, got cleared placeholder") - } - if !strings.Contains(renderDisplayParts(got[2].Parts), "[summary] bash:") { - t.Fatalf("expected summary prefix, got %q", renderDisplayParts(got[2].Parts)) - } - if renderDisplayParts(got[4].Parts) != "recent bash result" { - t.Fatalf("expected recent bash result retained, got %q", renderDisplayParts(got[4].Parts)) - } - if renderDisplayParts(got[6].Parts) != "latest bash result" { - t.Fatalf("expected latest bash result retained, got %q", renderDisplayParts(got[6].Parts)) - } - // 原始切片不被修改 - if renderDisplayParts(messages[2].Parts) != "old bash result" { - t.Fatalf("expected original slice unchanged, got %q", renderDisplayParts(messages[2].Parts)) - } -} - -// TestMicroCompactWithoutSummarizerFallsBackToSummary 验证未注册 summarizer 的工具使用通用兜底摘要。 -func TestMicroCompactWithoutSummarizerFallsBackToSummary(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: 
[]providertypes.ContentPart{providertypes.NewTextPart("old read result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - // 只为 bash 注册 summarizer,read_file 没有 - got := microCompactMessagesWithPolicies( - messages, - stubMicroCompactPolicySource{}, - 2, - stubMicroCompactSummarizerSource{ - "bash": func(content string, metadata map[string]string, isError bool) string { - return "[summary] bash: " + content - }, - }, - nil, - ) - - summary := renderDisplayParts(got[2].Parts) - if summary == microCompactClearedMessage { - t.Fatalf("expected fallback summary for read_file without summarizer, got cleared placeholder") - } - if !strings.Contains(summary, "[summary] filesystem_read_file") { - t.Fatalf("expected fallback summary to include tool name, got %q", summary) - } -} - -// TestMicroCompactMixedSpanWithSummarizer 验证混合工具 span 中部分有摘要、部分清除。 -func TestMicroCompactMixedSpanWithSummarizer(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "bash", Arguments: "{}"}, - {ID: "call-2", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: 
[]providertypes.ContentPart{providertypes.NewTextPart("bash output")}}, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("read output")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-4", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-4", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest bash")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("reply")}}, - } - - got := microCompactMessagesWithPolicies( - messages, - stubMicroCompactPolicySource{}, - 2, - stubMicroCompactSummarizerSource{ - "bash": func(content string, metadata map[string]string, isError bool) string { - return "[summary] " + content - }, - }, - nil, - ) - - // call-1 bash 在旧 span,有 summarizer,应生成摘要 - if !strings.Contains(renderDisplayParts(got[2].Parts), "[summary]") { - t.Fatalf("expected bash summary in old span, got %q", renderDisplayParts(got[2].Parts)) - } - summary := renderDisplayParts(got[3].Parts) - if summary == microCompactClearedMessage { - t.Fatalf("expected read_file fallback summary in old span, got cleared placeholder") - } - if !strings.Contains(summary, "[summary] filesystem_read_file") { - t.Fatalf("expected read_file fallback summary to include tool name, got %q", summary) - } -} - -// TestMicroCompactSummarizerReturnsEmptyFallsBackToSummary 验证 summarizer 返回空字符串时回退到通用摘要。 -func TestMicroCompactSummarizerReturnsEmptyFallsBackToSummary(t *testing.T) { - 
t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("middle result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - got := microCompactMessagesWithPolicies( - messages, - stubMicroCompactPolicySource{}, - 2, - stubMicroCompactSummarizerSource{ - "bash": func(content string, metadata map[string]string, isError bool) string { - return "" // 返回空 - }, - }, - nil, - ) - - summary := renderDisplayParts(got[2].Parts) - if summary == microCompactClearedMessage { - t.Fatalf("expected fallback summary when summarizer returns empty, got cleared placeholder") - } - if !strings.Contains(summary, "[summary] bash") { - t.Fatalf("expected fallback summary to include tool name, got %q", summary) - } -} - -// TestSummarizeOrClearWithNilSummarizers 验证 nil summarizers 回退到清除。 -func TestSummarizeOrClearWithNilSummarizers(t *testing.T) { - t.Parallel() - - got := summarizeOrClear( - providertypes.Message{Parts: []providertypes.ContentPart{providertypes.NewTextPart("test")}}, - "test", - nil, - nil, - 
) - if got != microCompactClearedMessage { - t.Fatalf("expected cleared message for nil summarizers, got %q", got) - } -} - -func TestSummarizeOrClearFallsBackWithoutRegisteredSummarizer(t *testing.T) { - t.Parallel() - - got := summarizeOrClear( - providertypes.Message{ToolCallID: "call-1"}, - "first line\nsecond line", - map[string]string{"call-1": "mcp.github.issue"}, - nil, - ) - if got == microCompactClearedMessage { - t.Fatalf("expected fallback summary for MCP tool, got cleared placeholder") - } - if !strings.Contains(got, "[summary] mcp.github.issue") { - t.Fatalf("expected MCP tool name in fallback summary, got %q", got) - } - if !strings.Contains(got, "lines=2") { - t.Fatalf("expected line count in fallback summary, got %q", got) - } -} - -// TestSummarizeOrClearWithToolNamesLookup 验证 toolNames map 查找工具名。 -func TestSummarizeOrClearWithToolNamesLookup(t *testing.T) { - t.Parallel() - - t.Run("found", func(t *testing.T) { - toolNames := map[string]string{"call-2": "filesystem_read_file"} - got := summarizeOrClear( - providertypes.Message{ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("content")}}, - "content", - toolNames, - stubMicroCompactSummarizerSource{ - "filesystem_read_file": func(content string, metadata map[string]string, isError bool) string { - return "[summary] " + content - }, - }, - ) - if !strings.Contains(got, "[summary]") { - t.Fatalf("expected summary, got %q", got) - } - }) - - t.Run("not_found_in_tool_names", func(t *testing.T) { - toolNames := map[string]string{"call-1": "bash"} - got := summarizeOrClear( - providertypes.Message{ToolCallID: "unknown-id", Parts: []providertypes.ContentPart{providertypes.NewTextPart("content")}}, - "content", - toolNames, - stubMicroCompactSummarizerSource{}, - ) - if got != microCompactClearedMessage { - t.Fatalf("expected cleared for unknown tool call id, got %q", got) - } - }) -} - -// TestSummarizeOrClearSanitizesSummary 验证摘要回灌前会执行控制字符净化与长度裁剪。 -func 
TestSummarizeOrClearSanitizesSummary(t *testing.T) { - t.Parallel() - - raw := strings.Repeat("x", microCompactSummaryMaxRunes+50) + "\n\t\x07" - got := summarizeOrClear( - providertypes.Message{ToolCallID: "call-1"}, - "ignored", - map[string]string{"call-1": "bash"}, - stubMicroCompactSummarizerSource{ - "bash": func(content string, metadata map[string]string, isError bool) string { - return raw - }, - }, - ) - - if strings.ContainsAny(got, "\n\t\a") { - t.Fatalf("expected control characters removed, got %q", got) - } - if utf8.RuneCountInString(got) > microCompactSummaryMaxRunes+3 { - t.Fatalf("expected summary capped, got %d runes", utf8.RuneCountInString(got)) - } - if !strings.HasSuffix(got, "...") { - t.Fatalf("expected truncated summary suffix, got %q", got) - } -} - -// TestSummarizeOrClearSanitizationEmptyFallback 验证净化后为空时会回退清理占位。 -func TestSummarizeOrClearSanitizationEmptyFallback(t *testing.T) { - t.Parallel() - - got := summarizeOrClear( - providertypes.Message{ToolCallID: "call-1"}, - "ignored", - map[string]string{"call-1": "bash"}, - stubMicroCompactSummarizerSource{ - "bash": func(content string, metadata map[string]string, isError bool) string { - return "\n\t\x07 " - }, - }, - ) - - if got == microCompactClearedMessage { - t.Fatalf("expected fallback summary when sanitized summary is empty, got cleared placeholder") - } - if !strings.Contains(got, "[summary] bash") { - t.Fatalf("expected fallback summary to include tool name, got %q", got) - } -} - -// TestIsToolCallSpanBoundaries 验证 span 边界异常时返回 false。 -func TestIsToolCallSpanBoundaries(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleAssistant, ToolCalls: []providertypes.ToolCall{{ID: "c1", Name: "bash"}}}, - } - - if isToolCallSpan(messages, internalcompact.MessageSpan{Start: -1, End: 0}) { - t.Fatal("expected false for negative start") - } - if isToolCallSpan(messages, internalcompact.MessageSpan{Start: 2, End: 3}) { - t.Fatal("expected false 
for out-of-range start") - } -} - -// TestCompactableToolCallIDsEmptyInput 验证空 tool call 输入时返回 nil。 -func TestCompactableToolCallIDsEmptyInput(t *testing.T) { - t.Parallel() - - ids, names := compactableToolCallIDs(nil, nil) - if ids != nil || names != nil { - t.Fatalf("expected nil maps for empty input, got ids=%v names=%v", ids, names) - } -} - -// TestHasCompactableToolMessage 验证工具块可压缩消息探测逻辑。 -func TestHasCompactableToolMessage(t *testing.T) { - t.Parallel() - - span := internalcompact.MessageSpan{Start: 0, End: 3} - ids := map[string]struct{}{"call-1": {}} - - t.Run("true_when_matching_tool_message_exists", func(t *testing.T) { - messages := []providertypes.Message{ - {Role: providertypes.RoleAssistant, ToolCalls: []providertypes.ToolCall{{ID: "call-1", Name: "bash"}}}, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("output")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("u")}}, - } - if !hasCompactableToolMessage(messages, span, ids, nil, nil) { - t.Fatal("expected compactable tool message to be found") - } - }) - - t.Run("false_when_tool_messages_are_not_compactable", func(t *testing.T) { - messages := []providertypes.Message{ - {Role: providertypes.RoleAssistant, ToolCalls: []providertypes.ToolCall{{ID: "call-1", Name: "bash"}}}, - {Role: providertypes.RoleTool, ToolCallID: "call-1", IsError: true, Parts: []providertypes.ContentPart{providertypes.NewTextPart("error")}}, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("other")}}, - } - if hasCompactableToolMessage(messages, span, ids, nil, nil) { - t.Fatal("expected no compactable tool message") - } - }) -} diff --git a/internal/context/microcompact_test.go b/internal/context/microcompact_test.go deleted file mode 100644 index 3d98c7b30..000000000 --- a/internal/context/microcompact_test.go +++ /dev/null @@ -1,618 +0,0 @@ 
-package context - -import ( - "strings" - "testing" - - providertypes "neo-code/internal/provider/types" - "neo-code/internal/tools" -) - -type stubMicroCompactPolicySource map[string]tools.MicroCompactPolicy - -func (s stubMicroCompactPolicySource) MicroCompactPolicy(name string) tools.MicroCompactPolicy { - if policy, ok := s[name]; ok { - return policy - } - return tools.MicroCompactPolicyCompact -} - -func TestMicroCompactMessagesClearsOlderCompactableToolResults(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old read result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current working reply")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{}, 2, nil, nil) - if len(got) != len(messages) { - t.Fatalf("expected message count to stay unchanged, got %d want %d", len(got), 
len(messages)) - } - if !strings.Contains(renderDisplayParts(got[2].Parts), "[summary] filesystem_read_file") { - t.Fatalf("expected oldest compactable tool result to fall back to summary, got %q", renderDisplayParts(got[2].Parts)) - } - if renderDisplayParts(got[4].Parts) != "recent bash result" { - t.Fatalf("expected recent compactable tool result to be retained, got %q", renderDisplayParts(got[4].Parts)) - } - if renderDisplayParts(got[6].Parts) != "latest webfetch result" { - t.Fatalf("expected latest compactable tool result to be retained, got %q", renderDisplayParts(got[6].Parts)) - } - if renderDisplayParts(messages[2].Parts) != "old read result" { - t.Fatalf("expected original slice to remain unchanged, got %q", renderDisplayParts(messages[2].Parts)) - } -} - -func TestMicroCompactMessagesHandlesEmptyAndInvalidSpanInputs(t *testing.T) { - t.Parallel() - - if got := microCompactMessages(nil); got != nil { - t.Fatalf("expected nil input to remain nil, got %+v", got) - } - - assistantOnly := []providertypes.Message{ - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "", Name: "bash", Arguments: "{}"}, - }, - }, - } - got := microCompactMessagesWithPolicies(assistantOnly, stubMicroCompactPolicySource{}, 0, nil, nil) - if len(got) != 1 || len(got[0].ToolCalls) != 1 { - t.Fatalf("expected invalid tool call id path to keep message untouched, got %+v", got) - } -} - -func TestMicroCompactMessagesKeepsProtectedTailUntouched(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-0", Name: "filesystem_grep", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-0", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old grep result")}}, - { - Role: providertypes.RoleAssistant, - 
ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent read result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("tail bash result")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{}, 2, nil, nil) - if !strings.Contains(renderDisplayParts(got[2].Parts), "[summary] filesystem_grep") { - t.Fatalf("expected old tool result before protected tail to fall back to summary, got %q", renderDisplayParts(got[2].Parts)) - } - if renderDisplayParts(got[4].Parts) != "recent read result" { - t.Fatalf("expected recent tool result before protected tail to remain, got %q", renderDisplayParts(got[4].Parts)) - } - if renderDisplayParts(got[6].Parts) != "recent bash result" { - t.Fatalf("expected second recent tool result before protected tail to remain, got %q", renderDisplayParts(got[6].Parts)) - } - if renderDisplayParts(got[9].Parts) != "tail bash result" { - t.Fatalf("expected protected tail tool result to remain, got %q", renderDisplayParts(got[9].Parts)) - } -} - -func TestMicroCompactMessagesKeepsPreservedToolsErrorsAndOrphans(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - { - Role: providertypes.RoleAssistant, - 
ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "custom_tool", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("custom result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "filesystem_edit", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("edit failed")}, IsError: true}, - {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "filesystem_write_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart(microCompactClearedMessage)}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-4", Name: "filesystem_grep", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-4", Parts: []providertypes.ContentPart{providertypes.NewTextPart("")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{ - "custom_tool": tools.MicroCompactPolicyPreserveHistory, - }, 2, nil, nil) - if renderDisplayParts(got[1].Parts) != "custom result" { - t.Fatalf("expected preserved tool result to remain, got %q", renderDisplayParts(got[1].Parts)) - } - if renderDisplayParts(got[3].Parts) != "edit failed" { - t.Fatalf("expected error tool result to remain, got %q", renderDisplayParts(got[3].Parts)) - } - if renderDisplayParts(got[4].Parts) != "orphan result" { - t.Fatalf("expected orphan tool result to remain, got %q", renderDisplayParts(got[4].Parts)) - } - if renderDisplayParts(got[6].Parts) != microCompactClearedMessage { - t.Fatalf("expected already cleared content 
to remain unchanged, got %q", renderDisplayParts(got[6].Parts)) - } - if renderDisplayParts(got[8].Parts) != "" { - t.Fatalf("expected empty tool result to remain empty, got %q", renderDisplayParts(got[8].Parts)) - } -} - -func TestMicroCompactMessagesClearsOnlyNonPreservedResultsInMixedToolSpan(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - {ID: "call-2", Name: "custom_tool", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("read result")}}, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("custom result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-4", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-4", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{ - "custom_tool": tools.MicroCompactPolicyPreserveHistory, - }, 2, nil, nil) - if !strings.Contains(renderDisplayParts(got[2].Parts), 
"[summary] filesystem_read_file") { - t.Fatalf("expected default compactable tool result to fall back to summary, got %q", renderDisplayParts(got[2].Parts)) - } - if renderDisplayParts(got[3].Parts) != "custom result" { - t.Fatalf("expected preserved tool result in mixed span to remain, got %q", renderDisplayParts(got[3].Parts)) - } - if len(got[1].ToolCalls) != 2 { - t.Fatalf("expected assistant tool call metadata to remain intact, got %+v", got[1].ToolCalls) - } -} - -func TestMicroCompactMessagesTreatsNewToolsAsCompactableByDefault(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "repo_search", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("old repo search result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{}, 2, nil, nil) - if !strings.Contains(renderDisplayParts(got[2].Parts), "[summary] repo_search") { - t.Fatalf("expected new tool result to be compacted into fallback summary by default, got %q", 
renderDisplayParts(got[2].Parts)) - } -} - -func TestMicroCompactMessagesPreservesSpawnSubAgentHistory(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: tools.ToolNameSpawnSubAgent, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("spawned analysis")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: tools.ToolNameBash, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: tools.ToolNameWebFetch, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{ - tools.ToolNameSpawnSubAgent: tools.MicroCompactPolicyPreserveHistory, - }, 1, nil, nil) - if renderDisplayParts(got[2].Parts) != "spawned analysis" { - t.Fatalf("expected spawn_subagent history to be preserved, got %q", renderDisplayParts(got[2].Parts)) - } -} - -func TestMicroCompactMessagesPreservesCodebaseReadHistory(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", 
Name: tools.ToolNameCodebaseRead, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("path: main.go\n\npackage main")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: tools.ToolNameBash, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: tools.ToolNameWebFetch, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{ - tools.ToolNameCodebaseRead: tools.MicroCompactPolicyPreserveHistory, - }, 2, nil, nil) - if renderDisplayParts(got[2].Parts) != "path: main.go\n\npackage main" { - t.Fatalf("expected codebase_read history to stay visible, got %q", renderDisplayParts(got[2].Parts)) - } -} - -func TestMicroCompactMessagesSkipsEmptyRecentSpansWhenCountingRetainedBudget(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_read_file", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("older read result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "filesystem_grep", Arguments: "{}"}, - }, 
- }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("middle grep result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "filesystem_edit", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("near edit result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-4", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-4", Parts: []providertypes.ContentPart{providertypes.NewTextPart("")}, IsError: true}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-5", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-5", Parts: []providertypes.ContentPart{providertypes.NewTextPart("")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{}, 2, nil, nil) - if !strings.Contains(renderDisplayParts(got[2].Parts), "[summary] filesystem_read_file") { - t.Fatalf("expected oldest valid tool result to fall back to summary, got %q", renderDisplayParts(got[2].Parts)) - } - if renderDisplayParts(got[4].Parts) != "middle grep result" { - t.Fatalf("expected middle valid tool result to remain, got %q", renderDisplayParts(got[4].Parts)) - } - if renderDisplayParts(got[6].Parts) != "near edit result" { - t.Fatalf("expected nearer valid tool result to remain, got %q", renderDisplayParts(got[6].Parts)) - } - if renderDisplayParts(got[8].Parts) != "" { - t.Fatalf("expected error/empty tool result to remain unchanged, got %q", 
renderDisplayParts(got[8].Parts)) - } - if renderDisplayParts(got[10].Parts) != "" { - t.Fatalf("expected empty recent tool result to remain unchanged, got %q", renderDisplayParts(got[10].Parts)) - } -} - -func TestMicroCompactMessagesSkipsToolMessagesWhenCompactableIDsMissing(t *testing.T) { - t.Parallel() - - messages := []providertypes.Message{ - {Role: providertypes.RoleTool, ToolCallID: "orphan", Parts: []providertypes.ContentPart{providertypes.NewTextPart("orphan result")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{}, 0, nil, nil) - if renderDisplayParts(got[0].Parts) != "orphan result" { - t.Fatalf("expected orphan tool result to remain, got %q", renderDisplayParts(got[0].Parts)) - } -} - -// TestMicroCompactPinnedResultNotCompacted verifies that tool results pinned by the pin checker are not compacted. -func TestMicroCompactPinnedResultNotCompacted(t *testing.T) { - t.Parallel() - - stubPin := stubMicroCompactPinChecker{ - "filesystem_write_file": map[string]bool{"README.md": true}, - } - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_write_file", Arguments: `{"path":"README.md"}`}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("README content")}, ToolMetadata: map[string]string{"path": "/project/README.md"}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, 
ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("current reply")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{}, 1, nil, stubPin) - if renderDisplayParts(got[2].Parts) != "README content" { - t.Fatalf("expected pinned README result to be preserved, got %q", renderDisplayParts(got[2].Parts)) - } -} - -// TestMicroCompactMixedPinnedAndNonPinned verifies that when pinned and non-pinned results share a span, only the non-pinned ones are compacted. -func TestMicroCompactMixedPinnedAndNonPinned(t *testing.T) { - t.Parallel() - - stubPin := stubMicroCompactPinChecker{ - "filesystem_write_file": map[string]bool{"README.md": true}, - } - - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "filesystem_write_file", Arguments: `{"path":"README.md"}`}, - {ID: "call-2", Name: "filesystem_write_file", Arguments: `{"path":"main.go"}`}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("README content")}, ToolMetadata: map[string]string{"path": "/project/README.md"}}, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("main.go content")}, ToolMetadata: map[string]string{"path": "/project/main.go"}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - {Role: 
providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("reply")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{}, 1, nil, stubPin) - if renderDisplayParts(got[2].Parts) != "README content" { - t.Fatalf("expected pinned README result preserved, got %q", renderDisplayParts(got[2].Parts)) - } - if !strings.Contains(renderDisplayParts(got[3].Parts), "[summary] filesystem_write_file") { - t.Fatalf("expected non-pinned main.go result to fall back to summary, got %q", renderDisplayParts(got[3].Parts)) - } -} - -func TestMicroCompactPinsCopyAndMoveUsingPersistedMetadataPaths(t *testing.T) { - t.Parallel() - - pinChecker := NewDefaultPinChecker() - - copyMessages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "copy-call", Name: tools.ToolNameFilesystemCopyFile, Arguments: `{"source_path":"main.go","destination_path":"go.mod"}`}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "copy-call", Parts: []providertypes.ContentPart{providertypes.NewTextPart("ok")}, ToolMetadata: map[string]string{ - "tool_name": tools.ToolNameFilesystemCopyFile, - "source_path": "/project/main.go", - "destination_path": "/project/go.mod", - }}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "recent-call", Name: tools.ToolNameBash, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "recent-call", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - copyGot := 
microCompactMessagesWithPolicies(copyMessages, stubMicroCompactPolicySource{}, 1, nil, pinChecker) - if renderDisplayParts(copyGot[2].Parts) != "ok" { - t.Fatalf("expected copy_file result touching go.mod to stay pinned, got %q", renderDisplayParts(copyGot[2].Parts)) - } - - moveMessages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "move-call", Name: tools.ToolNameFilesystemMoveFile, Arguments: `{"source_path":"package.json","destination_path":"package.backup.json"}`}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "move-call", Parts: []providertypes.ContentPart{providertypes.NewTextPart("ok")}, ToolMetadata: map[string]string{ - "tool_name": tools.ToolNameFilesystemMoveFile, - "source_path": "/project/package.json", - "destination_path": "/project/package.backup.json", - }}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "recent-call", Name: tools.ToolNameBash, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "recent-call", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - moveGot := microCompactMessagesWithPolicies(moveMessages, stubMicroCompactPolicySource{}, 1, nil, pinChecker) - if renderDisplayParts(moveGot[2].Parts) != "ok" { - t.Fatalf("expected move_file result touching package.json to stay pinned, got %q", renderDisplayParts(moveGot[2].Parts)) - } -} - -func TestMicroCompactStillCompactsCopyAndMoveWhenNoKeyFileIsTouched(t *testing.T) { - t.Parallel() - - pinChecker := NewDefaultPinChecker() - messages := []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, 
- { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "move-call", Name: tools.ToolNameFilesystemMoveFile, Arguments: `{"source_path":"main.go","destination_path":"main2.go"}`}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "move-call", Parts: []providertypes.ContentPart{providertypes.NewTextPart("ok")}, ToolMetadata: map[string]string{ - "tool_name": tools.ToolNameFilesystemMoveFile, - "source_path": "/project/main.go", - "destination_path": "/project/main2.go", - }}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "recent-call", Name: tools.ToolNameBash, Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "recent-call", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}}, - } - - got := microCompactMessagesWithPolicies(messages, stubMicroCompactPolicySource{}, 1, nil, pinChecker) - if !strings.Contains(renderDisplayParts(got[2].Parts), "[summary] filesystem_move_file") { - t.Fatalf("expected non-key move_file result to still compact into summary, got %q", renderDisplayParts(got[2].Parts)) - } -} - -// stubMicroCompactPinChecker implements MicroCompactPinChecker for tests. -type stubMicroCompactPinChecker map[string]map[string]bool - -func (s stubMicroCompactPinChecker) ShouldPin(toolName string, metadata map[string]string) bool { - paths, ok := s[toolName] - if !ok { - return false - } - path := metadata["path"] - if path == "" { - path = metadata["relative_path"] - } - for pinnedPath, shouldPin := range paths { - if shouldPin && strings.Contains(path, pinnedPath) { - return true - } - } - return false -} diff --git a/internal/context/pin_checker.go b/internal/context/pin_checker.go deleted file mode 100644 index 550c3835f..000000000 --- a/internal/context/pin_checker.go +++ /dev/null @@ -1,94 +0,0 @@ -package context 
- -import ( - "path/filepath" - "strings" - - "neo-code/internal/tools" -) - -// defaultPinPatterns lists basename glob patterns for key artifact files; matching tool results are excluded from micro-compaction. -var defaultPinPatterns = []string{ - "README*", - "*.spec.*", - "*.schema.*", - "docker-compose*", - "*migration*", - "Makefile", - "go.mod", - "package.json", -} - -// defaultPinToolNames restricts default pinning to tools that explicitly modify file contents, so it does not spread to read-only or custom tools. -var defaultPinToolNames = map[string]struct{}{ - tools.ToolNameFilesystemWriteFile: {}, - tools.ToolNameFilesystemEdit: {}, - tools.ToolNameFilesystemCopyFile: {}, - tools.ToolNameFilesystemMoveFile: {}, -} - -// pinChecker decides whether a tool result should be pinned, based on file-path glob patterns. -type pinChecker struct { - patterns []string -} - -// NewDefaultPinChecker returns a PinChecker that uses the default pin patterns. -func NewDefaultPinChecker() MicroCompactPinChecker { - return &pinChecker{patterns: defaultPinPatterns} -} - -// ShouldPin reports whether a tool result should be pinned: it extracts file paths from metadata and matches each basename against the glob patterns. -func (p *pinChecker) ShouldPin(toolName string, metadata map[string]string) bool { - if len(metadata) == 0 { - return false - } - if !toolSupportsPinnedRetention(toolName) { - return false - } - - for _, path := range candidatePinPaths(metadata) { - base := filepath.Base(path) - for _, pattern := range p.patterns { - if matched, _ := filepath.Match(pattern, base); matched { - return true - } - } - } - return false -} - -// candidatePinPaths extracts, in a stable order, the file-path fields eligible for pin checks. -func candidatePinPaths(metadata map[string]string) []string { - keys := []string{"relative_path", "path", "source_path", "destination_path"} - paths := make([]string, 0, len(keys)) - for _, key := range keys { - path := strings.TrimSpace(metadata[key]) - if path == "" { - continue - } - paths = append(paths, path) - } - if len(paths) == 0 { - return nil - } - return paths -} - -// toolSupportsPinnedRetention reports whether a tool may take part in the default pin policy, so that tools that do not modify files cannot widen the retained set. -func toolSupportsPinnedRetention(toolName string) bool { - _, ok := defaultPinToolNames[strings.TrimSpace(toolName)] - return ok -} - -// isPinnedToolMessage checks whether a tool message is pinned by the pin 
checker; pinned messages are excluded from micro-compaction. -func isPinnedToolMessage(toolName string, metadata map[string]string, checker MicroCompactPinChecker) bool { - if checker == nil || len(metadata) == 0 { - return false - } - return checker.ShouldPin(toolName, metadata) -} - -// toolNameFromCallID looks up the tool name for callID in the toolNames map. -func toolNameFromCallID(callID string, toolNames map[string]string) string { - return toolNames[strings.TrimSpace(callID)] -} diff --git a/internal/context/pin_checker_test.go b/internal/context/pin_checker_test.go deleted file mode 100644 index 7bc9cc5ae..000000000 --- a/internal/context/pin_checker_test.go +++ /dev/null @@ -1,156 +0,0 @@ -package context - -import ( - "testing" -) - -func TestDefaultPinCheckerMatchesKeyArtifacts(t *testing.T) { - t.Parallel() - - checker := NewDefaultPinChecker() - - tests := []struct { - toolName string - path string - expected bool - }{ - {toolName: "filesystem_write_file", path: "README.md", expected: true}, - {toolName: "filesystem_write_file", path: "README.txt", expected: true}, - {toolName: "filesystem_write_file", path: "readme.md", expected: false}, // glob matching is case-sensitive - {toolName: "filesystem_write_file", path: "api.spec.yaml", expected: true}, - {toolName: "filesystem_write_file", path: "design.spec.md", expected: true}, - {toolName: "filesystem_write_file", path: "db.schema.json", expected: true}, - {toolName: "filesystem_write_file", path: "schema.sql", expected: false}, // *.schema.* requires a "." before "schema" - {toolName: "filesystem_write_file", path: "db.schema.sql", expected: true}, - {toolName: "filesystem_write_file", path: "docker-compose.yml", expected: true}, - {toolName: "filesystem_write_file", path: "docker-compose.yaml", expected: true}, - {toolName: "filesystem_write_file", path: ".env", expected: false}, - {toolName: "filesystem_write_file", path: ".env.local", expected: false}, - {toolName: "filesystem_write_file", path: ".env.example", expected: false}, - {toolName: "filesystem_write_file", path: "01_migration.sql", 
expected: true}, - {toolName: "filesystem_write_file", path: "migration.rb", expected: true}, - {toolName: "filesystem_write_file", path: "create_users_migration.sql", expected: true}, - {toolName: "filesystem_write_file", path: "Makefile", expected: true}, - {toolName: "filesystem_write_file", path: "go.mod", expected: true}, - {toolName: "filesystem_write_file", path: "package.json", expected: true}, - {toolName: "filesystem_write_file", path: "main.go", expected: false}, - {toolName: "filesystem_write_file", path: "app.tsx", expected: false}, - {toolName: "filesystem_write_file", path: "index.js", expected: false}, - {toolName: "filesystem_write_file", path: "utils.py", expected: false}, - {toolName: "filesystem_write_file", path: "style.css", expected: false}, - {toolName: "filesystem_edit", path: "README.md", expected: true}, - {toolName: "filesystem_copy_file", path: "go.mod", expected: true}, - {toolName: "filesystem_move_file", path: "package.json", expected: true}, - {toolName: "filesystem_read_file", path: "README.md", expected: false}, - {toolName: "bash", path: "README.md", expected: false}, - } - - for _, tt := range tests { - got := checker.ShouldPin(tt.toolName, map[string]string{"path": "/project/" + tt.path}) - if got != tt.expected { - t.Errorf("ShouldPin(tool=%q, path=%q) = %v, want %v", tt.toolName, tt.path, got, tt.expected) - } - } -} - -func TestDefaultPinCheckerUsesRelativePath(t *testing.T) { - t.Parallel() - - checker := NewDefaultPinChecker() - - // relative_path takes precedence over path - got := checker.ShouldPin("filesystem_write_file", map[string]string{ - "relative_path": "api.spec.yaml", - }) - if !got { - t.Error("expected relative_path match for api.spec.yaml") - } -} - -func TestDefaultPinCheckerFallsBackToPath(t *testing.T) { - t.Parallel() - - checker := NewDefaultPinChecker() - - got := checker.ShouldPin("filesystem_write_file", map[string]string{ - "path": "/project/README.md", - }) - if !got { - t.Error("expected path fallback match for 
README.md") - } -} - -func TestDefaultPinCheckerSupportsCopyAndMoveMetadataFields(t *testing.T) { - t.Parallel() - - checker := NewDefaultPinChecker() - - copyPinned := checker.ShouldPin("filesystem_copy_file", map[string]string{ - "destination_path": "/project/go.mod", - }) - if !copyPinned { - t.Error("expected destination_path match for copy_file go.mod") - } - - movePinned := checker.ShouldPin("filesystem_move_file", map[string]string{ - "source_path": "/project/package.json", - }) - if !movePinned { - t.Error("expected source_path match for move_file package.json") - } - - notPinned := checker.ShouldPin("filesystem_move_file", map[string]string{ - "source_path": "/project/main.go", - "destination_path": "/project/main2.go", - }) - if notPinned { - t.Error("expected non-key source/destination paths to remain unpinned") - } -} - -func TestDefaultPinCheckerNoPathReturnsFalse(t *testing.T) { - t.Parallel() - - checker := NewDefaultPinChecker() - - got := checker.ShouldPin("filesystem_write_file", map[string]string{"workdir": "/tmp"}) - if got { - t.Error("expected false when no path in metadata") - } -} - -func TestDefaultPinCheckerEmptyMetadataReturnsFalse(t *testing.T) { - t.Parallel() - - checker := NewDefaultPinChecker() - - got := checker.ShouldPin("filesystem_write_file", nil) - if got { - t.Error("expected false for nil metadata") - } -} - -func TestDefaultPinCheckerBashToolNotPinned(t *testing.T) { - t.Parallel() - - checker := NewDefaultPinChecker() - - // bash tool metadata contains only workdir, so it should not be pinned - got := checker.ShouldPin("bash", map[string]string{"workdir": "/project"}) - if got { - t.Error("expected bash tool with workdir only to not be pinned") - } -} - -func TestDefaultPinCheckerIgnoresPathMetadataForUnsupportedTool(t *testing.T) { - t.Parallel() - - checker := NewDefaultPinChecker() - - got := checker.ShouldPin("filesystem_read_file", map[string]string{ - "path": "/project/README.md", - }) - if got { - t.Error("expected unsupported tool with path metadata to not 
be pinned") - } -} diff --git a/internal/context/projection.go b/internal/context/projection.go index 9db200dfa..bcdfbfb8d 100644 --- a/internal/context/projection.go +++ b/internal/context/projection.go @@ -14,7 +14,36 @@ const ( recentWindowToolContentTailChars = 300 ) -const truncatedExcerptMarker = "\n...[truncated]...\n" +const ( + truncatedExcerptMarker = "\n...[truncated]...\n" + contentTruncatedForModelContext = "[content truncated for model context]" +) + +// cloneContextMessages deep-copies the message slice so read-time projection cannot contaminate the original session messages held by the runtime. +func cloneContextMessages(messages []providertypes.Message) []providertypes.Message { + if len(messages) == 0 { + return nil + } + + cloned := make([]providertypes.Message, 0, len(messages)) + for _, message := range messages { + cloned = append(cloned, cloneSingleMessage(message)) + } + return cloned +} + +// cloneSingleMessage deep-copies a single message, isolating the underlying references of ToolCalls and ToolMetadata. +func cloneSingleMessage(msg providertypes.Message) providertypes.Message { + next := msg + next.ToolCalls = append([]providertypes.ToolCall(nil), msg.ToolCalls...) 
+ if len(msg.ToolMetadata) > 0 { + next.ToolMetadata = make(map[string]string, len(msg.ToolMetadata)) + for key, value := range msg.ToolMetadata { + next.ToolMetadata[key] = value + } + } + return next +} // ProjectToolMessagesForModel 原地投影 tool 消息,复用主链路对模型可见的只读格式化规则。 func ProjectToolMessagesForModel(messages []providertypes.Message) []providertypes.Message { @@ -29,6 +58,11 @@ func ProjectToolMessagesForModel(messages []providertypes.Message) []providertyp return messages } +// projectReadTimeMessagesForModel 构造 provider 读取路径的只读消息投影,避免改写会话原始消息并限制工具输出体积。 +func projectReadTimeMessagesForModel(messages []providertypes.Message) []providertypes.Message { + return sanitizeProjectedToolMessages(ProjectToolMessagesForModel(cloneContextMessages(messages))) +} + // BuildRecentMessagesForModel 从会话尾部构造 provider-safe 的最近消息窗口,避免保留非法 tool call 片段。 func BuildRecentMessagesForModel(messages []providertypes.Message, limit int) []providertypes.Message { if len(messages) == 0 || limit <= 0 { @@ -79,7 +113,7 @@ func BuildRecentMessagesForModel(messages []providertypes.Message, limit int) [] return nil } - return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected))) + return sanitizeProjectedToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected))) } // BuildMemoExtractionMessagesForModel 构造完整 run 的 provider-safe 记忆提取上下文。 @@ -115,7 +149,7 @@ func BuildMemoExtractionMessagesForModel(messages []providertypes.Message) []pro return nil } - return sanitizeRecentWindowToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected))) + return sanitizeProjectedToolMessages(ProjectToolMessagesForModel(cloneContextMessages(selected))) } // matchedToolCallSpan 返回 assistant tool call 与其完整 tool 响应组成的合法窗口下标集合。 @@ -184,9 +218,6 @@ func isInjectableToolMessage(message providertypes.Message) bool { return false } content := strings.TrimSpace(renderDisplayParts(message.Parts)) - if content == microCompactClearedMessage { - return false - } 
return content != "" || len(message.ToolMetadata) > 0 } @@ -205,8 +236,8 @@ func recentWindowMessageBudget(limit int) int { return budget } -// sanitizeRecentWindowToolMessages 缩减 tool 消息内容,降低 memo 提取链路对原始工具输出的暴露面。 -func sanitizeRecentWindowToolMessages(messages []providertypes.Message) []providertypes.Message { +// sanitizeProjectedToolMessages 缩减 tool 消息内容,降低模型上下文对原始工具输出的暴露面。 +func sanitizeProjectedToolMessages(messages []providertypes.Message) []providertypes.Message { for index := range messages { message := messages[index] if message.Role != providertypes.RoleTool { @@ -235,7 +266,7 @@ func sanitizeProjectedToolContent(content string) string { limited, truncated := sanitizeToolExcerpt(body) lines := []string{prefix, "content_excerpt:", limited} if truncated { - lines = append(lines, "[content truncated for memo extraction]") + lines = append(lines, contentTruncatedForModelContext) } return strings.Join(lines, "\n") } @@ -253,11 +284,11 @@ func sanitizeRawToolContent(content string) string { return strings.Join([]string{ "content_excerpt:", limited, - "[content truncated for memo extraction]", + contentTruncatedForModelContext, }, "\n") } -// sanitizeToolExcerpt 保留工具输出的头尾窗口,避免 memo 提取遗漏尾部关键错误。 +// sanitizeToolExcerpt 保留工具输出的头尾窗口,避免模型上下文遗漏尾部关键错误。 func sanitizeToolExcerpt(text string) (string, bool) { total := utf8.RuneCountInString(text) limit := recentWindowToolContentHeadChars + recentWindowToolContentTailChars diff --git a/internal/context/projection_test.go b/internal/context/projection_test.go index 8e78a870b..b611e83fb 100644 --- a/internal/context/projection_test.go +++ b/internal/context/projection_test.go @@ -24,12 +24,6 @@ func TestProjectToolMessagesForModelSkipsMessagesThatCannotBeProjected(t *testin Parts: []providertypes.ContentPart{providertypes.NewTextPart(" ")}, ToolMetadata: map[string]string{"tool_name": "bash"}, }, - { - Role: providertypes.RoleTool, - ToolCallID: "call-3", - Parts: 
[]providertypes.ContentPart{providertypes.NewTextPart(microCompactClearedMessage)}, - ToolMetadata: map[string]string{"tool_name": "bash"}, - }, { Role: providertypes.RoleTool, ToolCallID: "call-4", @@ -48,10 +42,7 @@ func TestProjectToolMessagesForModelSkipsMessagesThatCannotBeProjected(t *testin if !strings.Contains(renderDisplayParts(projected[2].Parts), "tool result") || projected[2].ToolMetadata != nil { t.Fatalf("metadata-only tool message should be projected, got %+v", projected[2]) } - if renderDisplayParts(projected[3].Parts) != microCompactClearedMessage || projected[3].ToolMetadata == nil { - t.Fatalf("cleared tool content should not be projected, got %+v", projected[3]) - } - if !strings.Contains(renderDisplayParts(projected[4].Parts), "tool result") || projected[4].ToolMetadata != nil { - t.Fatalf("valid tool message should be projected, got %+v", projected[4]) + if !strings.Contains(renderDisplayParts(projected[3].Parts), "tool result") || projected[3].ToolMetadata != nil { + t.Fatalf("valid tool message should be projected, got %+v", projected[3]) } } @@ -244,11 +235,55 @@ func TestSanitizeProjectedToolContent(t *testing.T) { if !strings.Contains(sanitized, "TAIL-MARKER") { t.Fatalf("expected tail content to be preserved, got %q", sanitized) } - if !strings.Contains(sanitized, "[content truncated for memo extraction]") { + if !strings.Contains(sanitized, contentTruncatedForModelContext) { t.Fatalf("expected truncation marker, got %q", sanitized) } } +func TestProjectReadTimeMessagesForModelBoundaries(t *testing.T) { + t.Parallel() + + if got := projectReadTimeMessagesForModel(nil); got != nil { + t.Fatalf("expected nil for empty read-time projection, got %+v", got) + } + + messages := []providertypes.Message{ + { + Role: providertypes.RoleAssistant, + ToolCalls: []providertypes.ToolCall{ + {ID: "call-1", Name: "filesystem_read_file", Arguments: `{"path":"README.md"}`}, + }, + }, + { + Role: providertypes.RoleTool, + ToolCallID: "call-1", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("short result")}, +
ToolMetadata: map[string]string{"tool_name": "filesystem_read_file", "path": "README.md"}, + }, + } + + projected := projectReadTimeMessagesForModel(messages) + if len(projected) != 2 { + t.Fatalf("len(projected) = %d, want 2", len(projected)) + } + projectedText := renderDisplayParts(projected[1].Parts) + if !strings.Contains(projectedText, "content_excerpt:") { + t.Fatalf("expected read-time tool content to use excerpt marker, got %q", projectedText) + } + if !strings.Contains(projectedText, "short result") { + t.Fatalf("expected short result to remain visible, got %q", projectedText) + } + if strings.Contains(projectedText, contentTruncatedForModelContext) { + t.Fatalf("did not expect truncation marker for short content, got %q", projectedText) + } + if projected[1].ToolMetadata != nil { + t.Fatalf("expected projected metadata to be cleared, got %#v", projected[1].ToolMetadata) + } + if renderDisplayParts(messages[1].Parts) != "short result" || messages[1].ToolMetadata == nil { + t.Fatalf("expected source messages to remain unchanged, got %+v", messages[1]) + } +} + func TestMatchedToolCallSpanRejectsInvalidAssistantStates(t *testing.T) { t.Parallel() @@ -370,11 +405,6 @@ func TestIsInjectableToolMessage(t *testing.T) { message: providertypes.Message{Role: providertypes.RoleTool, Parts: []providertypes.ContentPart{providertypes.NewTextPart(" ")}, ToolMetadata: map[string]string{"tool_name": "bash"}}, want: true, }, - { - name: "cleared", - message: providertypes.Message{Role: providertypes.RoleTool, Parts: []providertypes.ContentPart{providertypes.NewTextPart(microCompactClearedMessage)}}, - want: false, - }, { name: "valid", message: providertypes.Message{Role: providertypes.RoleTool, Parts: []providertypes.ContentPart{providertypes.NewTextPart("ok")}}, @@ -407,7 +437,20 @@ func TestSanitizeProjectedToolContentFallsBackForRawPayload(t *testing.T) { if !strings.Contains(sanitized, "...[truncated]...") { t.Fatalf("expected middle truncation marker, got %q", 
sanitized) } - if !strings.Contains(sanitized, "[content truncated for memo extraction]") { + if !strings.Contains(sanitized, contentTruncatedForModelContext) { t.Fatalf("expected truncation marker, got %q", sanitized) } } + +func TestSanitizeProjectedToolContentRawPayloadBoundaries(t *testing.T) { + t.Parallel() + + if got := sanitizeProjectedToolContent(" "); got != "" { + t.Fatalf("expected blank raw content to sanitize to empty string, got %q", got) + } + + const shortRaw = "short raw payload" + if got := sanitizeProjectedToolContent(shortRaw); got != shortRaw { + t.Fatalf("expected short raw payload unchanged, got %q", got) + } +} diff --git a/internal/context/source_plan_mode.go b/internal/context/source_plan_mode.go index 59ab26d67..d86c10986 100644 --- a/internal/context/source_plan_mode.go +++ b/internal/context/source_plan_mode.go @@ -34,7 +34,7 @@ func (planModeContextSource) Sections(ctx context.Context, input BuildInput) ([] if stage == "plan" { noPlanHint := promptSection{ Title: "Current Plan", - Content: "status: none\n\nNo current plan exists. You must create one by outputting a `plan_spec` + `summary_candidate` JSON before this turn ends.", + Content: "status: none\n\nNo current plan exists. 
You must create one before this turn ends by outputting a visible Markdown plan, followed by one compact `plan_spec` + `summary_candidate` JSON object inside an HTML comment.", } sections = append(sections, noPlanHint) } diff --git a/internal/context/source_plan_mode_test.go b/internal/context/source_plan_mode_test.go index 25cae39d3..76e19dd49 100644 --- a/internal/context/source_plan_mode_test.go +++ b/internal/context/source_plan_mode_test.go @@ -50,6 +50,37 @@ func TestPlanModeSectionsReturnsWithoutPlanWhenNil(t *testing.T) { } } +func TestPlanModeSectionsNoCurrentPlanUsesHTMLCommentContract(t *testing.T) { + t.Parallel() + + source := planModeContextSource{} + sections, err := source.Sections(context.Background(), BuildInput{ + AgentMode: agentsession.AgentModePlan, + PlanStage: "plan", + CurrentPlan: nil, + }) + if err != nil { + t.Fatalf("Sections() error = %v", err) + } + var currentPlanContent string + for _, section := range sections { + if section.Title == "Current Plan" { + currentPlanContent = section.Content + break + } + } + if currentPlanContent == "" { + t.Fatal("expected Current Plan section when plan stage has no current plan") + } + if !strings.Contains(currentPlanContent, "visible Markdown plan") || + !strings.Contains(currentPlanContent, "inside an HTML comment") { + t.Fatalf("Current Plan hint = %q, want Markdown plus HTML comment JSON contract", currentPlanContent) + } + if strings.Contains(currentPlanContent, "outputting a `plan_spec` + `summary_candidate` JSON") { + t.Fatalf("Current Plan hint should not use old JSON-only wording: %q", currentPlanContent) + } +} + func TestPlanModeSectionsContextError(t *testing.T) { t.Parallel() diff --git a/internal/context/source_todos.go b/internal/context/source_todos.go index 2fa8db1cb..a677245e5 100644 --- a/internal/context/source_todos.go +++ b/internal/context/source_todos.go @@ -27,7 +27,12 @@ func (todosSource) Sections(ctx context.Context, input BuildInput) ([]promptSect return nil, err } if 
len(input.Todos) == 0 { - return nil, nil + return []promptSection{ + { + Title: "Todo State", + Content: "None", + }, + }, nil } active := make([]agentsession.TodoItem, 0, len(input.Todos)) @@ -37,7 +42,12 @@ func (todosSource) Sections(ctx context.Context, input BuildInput) ([]promptSect } } if len(active) == 0 { - return nil, nil + return []promptSection{ + { + Title: "Todo State", + Content: "None", + }, + }, nil } sort.SliceStable(active, func(i, j int) bool { diff --git a/internal/context/source_todos_test.go b/internal/context/source_todos_test.go index a0bf7b66a..e4708ac15 100644 --- a/internal/context/source_todos_test.go +++ b/internal/context/source_todos_test.go @@ -74,8 +74,8 @@ func TestTodosSourceSectionsBoundaries(t *testing.T) { if err != nil { t.Fatalf("Sections() error = %v", err) } - if sections != nil { - t.Fatalf("Sections() = %+v, want nil", sections) + if len(sections) != 1 || sections[0].Content != "None" { + t.Fatalf("Sections() = %+v, want single section with 'None'", sections) } ctx, cancel := stdcontext.WithCancel(stdcontext.Background()) @@ -100,8 +100,8 @@ func TestTodosSourceSectionsAllTerminal(t *testing.T) { if err != nil { t.Fatalf("Sections() error = %v", err) } - if sections != nil { - t.Fatalf("Sections() = %+v, want nil for all terminal todos", sections) + if len(sections) != 1 || sections[0].Content != "None" { + t.Fatalf("Sections() = %+v, want single section with 'None' for all terminal todos", sections) } } diff --git a/internal/context/types.go b/internal/context/types.go index a0296d266..a5106324f 100644 --- a/internal/context/types.go +++ b/internal/context/types.go @@ -7,7 +7,6 @@ import ( "neo-code/internal/repository" agentsession "neo-code/internal/session" "neo-code/internal/skills" - "neo-code/internal/tools" ) // Builder builds the provider-facing context for a single model round. 
@@ -73,32 +72,7 @@ type RepositoryRetrievalSection struct { Query string } -// MicroCompactPolicySource 定义 context 读取工具 micro compact 策略的最小依赖。 -type MicroCompactPolicySource interface { - MicroCompactPolicy(name string) tools.MicroCompactPolicy -} - -// MicroCompactSummarizerSource 定义 context 查找按工具内容摘要器的最小依赖。 -type MicroCompactSummarizerSource interface { - MicroCompactSummarizer(name string) tools.ContentSummarizer -} - -// MicroCompactPinChecker 定义上下文层判断单个工具结果是否应钉住(不参与微压缩)的接口。 -type MicroCompactPinChecker interface { - ShouldPin(toolName string, metadata map[string]string) bool -} - -// MicroCompactConfig 聚合微压缩所需的三个依赖源,简化 Builder 构造参数。 -// 三个子接口仍各自遵循接口隔离原则;MicroCompactConfig 仅作为构造时的参数打包容器。 -type MicroCompactConfig struct { - Policies MicroCompactPolicySource - Summarizers MicroCompactSummarizerSource - PinChecker MicroCompactPinChecker -} - -// CompactOptions controls read-time compact behavior inside the context builder. +// CompactOptions controls read-time context behavior inside the context builder. 
type CompactOptions struct { - DisableMicroCompact bool - MicroCompactRetainedToolSpans int - ReadTimeMaxMessageSpans int + ReadTimeMaxMessageSpans int } diff --git a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go index 70377cc7e..7e54f27a5 100644 --- a/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go +++ b/internal/gateway/adapters/urlscheme/dispatcher_integration_unix_test.go @@ -189,6 +189,20 @@ func (s *urlschemeIntegrationRuntimeStub) CreateSession( return strings.TrimSpace("session-review-integration"), nil } +func (s *urlschemeIntegrationRuntimeStub) SaveSessionAsset( + context.Context, + gateway.SaveSessionAssetInput, +) (gateway.SessionAssetMeta, error) { + return gateway.SessionAssetMeta{}, nil +} + +func (s *urlschemeIntegrationRuntimeStub) OpenSessionAsset( + context.Context, + gateway.OpenSessionAssetInput, +) (gateway.OpenSessionAssetResult, error) { + return gateway.OpenSessionAssetResult{}, nil +} + func (s *urlschemeIntegrationRuntimeStub) ListSessionTodos( context.Context, gateway.ListSessionTodosInput, diff --git a/internal/gateway/bootstrap.go b/internal/gateway/bootstrap.go index 40b378b02..397c591b0 100644 --- a/internal/gateway/bootstrap.go +++ b/internal/gateway/bootstrap.go @@ -523,15 +523,19 @@ func dispatchRunFrameWithSubjectID( ) } if relayExists && relay != nil { - errorCode := "INTERNAL_ERROR" + errorCode := ErrorCodeInternalError.String() errorMessage := "run failed" + stopReason := "" if failedFrame.Error != nil { - if normalizedCode := strings.ToUpper(strings.TrimSpace(failedFrame.Error.Code)); normalizedCode != "" { + if normalizedCode := strings.TrimSpace(failedFrame.Error.Code); normalizedCode != "" { errorCode = normalizedCode } if normalizedMessage := strings.TrimSpace(failedFrame.Error.Message); normalizedMessage != "" { errorMessage = normalizedMessage } + if strings.TrimSpace(failedFrame.Error.Code) == 
ErrorCodeMaxTurnExceeded.String() { + stopReason = ErrorCodeMaxTurnExceeded.String() + } } fallbackSessionID := strings.TrimSpace(frameSnapshot.SessionID) if fallbackSessionID == "" { @@ -542,14 +546,18 @@ func dispatchRunFrameWithSubjectID( fallbackRunID = strings.TrimSpace(inputSnapshot.RunID) } if fallbackSessionID != "" { + payload := map[string]any{ + "code": errorCode, + "message": errorMessage, + } + if stopReason != "" { + payload["stop_reason"] = stopReason + } relay.PublishRuntimeEvent(RuntimeEvent{ Type: RuntimeEventTypeRunError, SessionID: fallbackSessionID, RunID: fallbackRunID, - Payload: map[string]any{ - "code": errorCode, - "message": errorMessage, - }, + Payload: payload, }) } } @@ -1683,11 +1691,68 @@ func handleResolvePermissionFrame(ctx context.Context, frame MessageFrame, runti } } +// handleApprovePlanFrame 处理计划批准请求,并把能力收敛到可选 runtime 端口。 +func handleApprovePlanFrame(ctx context.Context, frame MessageFrame, runtimePort RuntimePort) MessageFrame { + if runtimePort == nil { + return runtimePortUnavailableFrame(frame) + } + subjectID, subjectErr := requireAuthenticatedSubjectID(ctx) + if subjectErr != nil { + return errorFrame(frame, subjectErr) + } + approvalPort, approvalErr := requirePlanApprovalRuntimePort(runtimePort) + if approvalErr != nil { + return errorFrame(frame, approvalErr) + } + + input, err := decodeApprovePlanPayload(frame.Payload) + if err != nil { + return errorFrame(frame, err) + } + input.SubjectID = subjectID + if input.SessionID == "" { + input.SessionID = strings.TrimSpace(frame.SessionID) + } + if input.SessionID == "" { + return errorFrame(frame, NewMissingRequiredFieldError("payload.session_id")) + } + if input.PlanID == "" { + return errorFrame(frame, NewMissingRequiredFieldError("payload.plan_id")) + } + if input.Revision <= 0 { + return errorFrame(frame, NewFrameError(ErrorCodeInvalidAction, "invalid approve_plan revision")) + } + + callCtx, cancel := withRuntimeOperationTimeout(ctx) + defer cancel() + result, 
approveErr := approvalPort.ApprovePlan(callCtx, input) + if approveErr != nil { + return runtimeCallFailedFrame(callCtx, frame, approveErr, "approve_plan") + } + + return MessageFrame{ + Type: FrameTypeAck, + Action: FrameActionApprovePlan, + RequestID: frame.RequestID, + SessionID: input.SessionID, + Payload: result, + } +} + // runtimePortUnavailableFrame 在 runtime 未注入时返回统一错误。 func runtimePortUnavailableFrame(frame MessageFrame) MessageFrame { return errorFrame(frame, NewFrameError(ErrorCodeInternalError, "runtime port is unavailable")) } +// requirePlanApprovalRuntimePort 校验当前 runtime 端口是否支持计划批准能力。 +func requirePlanApprovalRuntimePort(runtimePort RuntimePort) (PlanApprovalRuntimePort, *FrameError) { + approvalPort, ok := runtimePort.(PlanApprovalRuntimePort) + if !ok { + return nil, NewFrameError(ErrorCodeInternalError, "plan approval runtime port is unavailable") + } + return approvalPort, nil +} + // requireManagementRuntimePort 校验当前 runtime 端口是否支持管理面扩展能力。 func requireManagementRuntimePort(runtimePort RuntimePort) (ManagementRuntimePort, *FrameError) { managementPort, ok := runtimePort.(ManagementRuntimePort) @@ -1785,6 +1850,16 @@ func runtimeCallFailedFrame(ctx context.Context, frame MessageFrame, err error, case errors.Is(err, ErrRuntimeResourceNotFound): errorCode = ErrorCodeResourceNotFound message = fmt.Sprintf("%s target not found", normalizedOperation) + case errors.Is(err, ErrRuntimeMaxTurnExceeded): + errorCode = ErrorCodeMaxTurnExceeded + if detail := RuntimeMaxTurnExceededDetail(err); detail != "" { + message = detail + } else { + message = fmt.Sprintf("%s max turn exceeded", normalizedOperation) + } + case errors.Is(err, ErrRuntimeInvalidAction): + errorCode = ErrorCodeInvalidAction + message = fmt.Sprintf("%s invalid action", normalizedOperation) case errors.Is(err, context.DeadlineExceeded): errorCode = ErrorCodeTimeout message = fmt.Sprintf("%s timed out", normalizedOperation) diff --git a/internal/gateway/bootstrap_test.go 
b/internal/gateway/bootstrap_test.go index 441372589..945201d43 100644 --- a/internal/gateway/bootstrap_test.go +++ b/internal/gateway/bootstrap_test.go @@ -29,6 +29,7 @@ type bootstrapRuntimeStub struct { listSessionSkillsFn func(ctx context.Context, input ListSessionSkillsInput) ([]SessionSkillState, error) listAvailableFn func(ctx context.Context, input ListAvailableSkillsInput) ([]AvailableSkillState, error) resolvePermissionFn func(ctx context.Context, input PermissionResolutionInput) error + approvePlanFn func(ctx context.Context, input ApprovePlanInput) (ApprovePlanResult, error) cancelRunFn func(ctx context.Context, input CancelInput) (bool, error) events <-chan RuntimeEvent listSessionsFn func(ctx context.Context) ([]SessionSummary, error) @@ -58,6 +59,10 @@ type bootstrapRuntimeStub struct { checkpointDiffFn func(ctx context.Context, input CheckpointDiffInput) (CheckpointDiffResult, error) } +type runtimePortWithoutPlanApproval struct { + RuntimePort +} + func (s *bootstrapRuntimeStub) Run(ctx context.Context, input RunInput) error { if s != nil && s.runFn != nil { return s.runFn(ctx, input) @@ -140,6 +145,13 @@ func (s *bootstrapRuntimeStub) ResolvePermission(ctx context.Context, input Perm return nil } +func (s *bootstrapRuntimeStub) ApprovePlan(ctx context.Context, input ApprovePlanInput) (ApprovePlanResult, error) { + if s != nil && s.approvePlanFn != nil { + return s.approvePlanFn(ctx, input) + } + return ApprovePlanResult{}, nil +} + func (s *bootstrapRuntimeStub) ResolveUserQuestion(ctx context.Context, input UserQuestionAnswerInput) error { return nil } @@ -299,6 +311,14 @@ func (s *bootstrapRuntimeStub) CreateSession(ctx context.Context, input CreateSe return strings.TrimSpace(input.SessionID), nil } +func (s *bootstrapRuntimeStub) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{SessionID: input.SessionID, AssetID: "asset_test", MimeType: input.MimeType}, nil +} + +func (s 
*bootstrapRuntimeStub) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + func (s *bootstrapRuntimeStub) ListCheckpoints(ctx context.Context, input ListCheckpointsInput) ([]CheckpointEntry, error) { if s != nil && s.listCheckpointsFn != nil { return s.listCheckpointsFn(ctx, input) @@ -2045,6 +2065,95 @@ ASSERT: } } +func TestDispatchRequestFrameRunMaxTurnFailurePublishesStopReason(t *testing.T) { + relay := NewStreamRelay(StreamRelayOptions{}) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + connectionID := NewConnectionID() + connectionCtx := WithConnectionID(ctx, connectionID) + connectionCtx = WithStreamRelay(connectionCtx, relay) + + messageCh := make(chan RelayMessage, 8) + if err := relay.RegisterConnection(ConnectionRegistration{ + ConnectionID: connectionID, + Channel: StreamChannelIPC, + Context: connectionCtx, + Cancel: cancel, + Write: func(message RelayMessage) error { + messageCh <- message + return nil + }, + Close: func() {}, + }); err != nil { + t.Fatalf("register connection: %v", err) + } + defer relay.dropConnection(connectionID) + + if err := relay.BindConnection(connectionID, StreamBinding{ + SessionID: "run-session-max-turn", + RunID: "run-max-turn", + Channel: StreamChannelIPC, + Role: StreamRoleNone, + Explicit: true, + }); err != nil { + t.Fatalf("bind connection: %v", err) + } + + runtime := &bootstrapRuntimeStub{ + runFn: func(_ context.Context, _ RunInput) error { + return NewRuntimeMaxTurnExceededError("runtime: max turn limit reached (40)") + }, + } + response := dispatchRequestFrame(connectionCtx, MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRun, + RequestID: "req-run-max-turn", + SessionID: "run-session-max-turn", + RunID: "run-max-turn", + InputText: "hello", + }, runtime) + if response.Type != FrameTypeAck { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeAck) + } + + deadline 
:= time.After(2 * time.Second) + for { + select { + case message := <-messageCh: + notification, ok := message.Payload.(protocol.JSONRPCNotification) + if !ok || notification.Method != protocol.MethodGatewayEvent { + continue + } + eventFrame := MessageFrame{} + raw, err := json.Marshal(notification.Params) + if err != nil { + t.Fatalf("marshal payload params: %v", err) + } + if err := json.Unmarshal(raw, &eventFrame); err != nil { + t.Fatalf("unmarshal event frame: %v", err) + } + payloadMap, _ := eventFrame.Payload.(map[string]any) + if strings.TrimSpace(fmt.Sprint(payloadMap["event_type"])) != string(RuntimeEventTypeRunError) { + continue + } + envelope, _ := payloadMap["payload"].(map[string]any) + if got := strings.TrimSpace(fmt.Sprint(envelope["code"])); got != ErrorCodeMaxTurnExceeded.String() { + t.Fatalf("payload.code = %q, want %q", got, ErrorCodeMaxTurnExceeded.String()) + } + if got := strings.TrimSpace(fmt.Sprint(envelope["stop_reason"])); got != ErrorCodeMaxTurnExceeded.String() { + t.Fatalf("payload.stop_reason = %q, want %q", got, ErrorCodeMaxTurnExceeded.String()) + } + if got := strings.TrimSpace(fmt.Sprint(envelope["message"])); got != "runtime: max turn limit reached (40)" { + t.Fatalf("payload.message = %q, want max turn detail", got) + } + return + case <-deadline: + t.Fatal("expected max-turn run_error event") + } + } +} + func TestRuntimeCallFailedFrameSanitizesErrorAndMapsCode(t *testing.T) { var buf bytes.Buffer ctx := WithGatewayLogger(context.Background(), log.New(&buf, "", 0)) @@ -2088,6 +2197,27 @@ func TestRuntimeCallFailedFrameSanitizesErrorAndMapsCode(t *testing.T) { if canceledErr.Error.Message != "run canceled" { t.Fatalf("canceled message = %q, want %q", canceledErr.Error.Message, "run canceled") } + + invalidActionErr := runtimeCallFailedFrame(context.Background(), frame, ErrRuntimeInvalidAction, "approve_plan") + if invalidActionErr.Error == nil || invalidActionErr.Error.Code != ErrorCodeInvalidAction.String() { + 
t.Fatalf("invalid action payload = %#v, want invalid_action", invalidActionErr.Error) + } + if invalidActionErr.Error.Message != "approve_plan invalid action" { + t.Fatalf("invalid action message = %q, want %q", invalidActionErr.Error.Message, "approve_plan invalid action") + } + + maxTurnErr := runtimeCallFailedFrame( + context.Background(), + frame, + NewRuntimeMaxTurnExceededError("runtime: max turn limit reached (40)"), + "run", + ) + if maxTurnErr.Error == nil || maxTurnErr.Error.Code != ErrorCodeMaxTurnExceeded.String() { + t.Fatalf("max turn error payload = %#v, want max_turn_exceeded", maxTurnErr.Error) + } + if maxTurnErr.Error.Message != "runtime: max turn limit reached (40)" { + t.Fatalf("max turn message = %q, want runtime detail", maxTurnErr.Error.Message) + } } func TestNormalizeRunID(t *testing.T) { @@ -2594,6 +2724,168 @@ func TestHandleCancelListLoadResolveBranches(t *testing.T) { t.Fatalf("response message = %q, want %q", response.Error.Message, "resolve_permission failed") } }) + + t.Run("approve plan invalid payload", func(t *testing.T) { + response := handleApprovePlanFrame(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + RequestID: "approve-invalid", + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "", + "revision": 1, + }, + }, &bootstrapRuntimeStub{}) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeMissingRequiredField.String() { + t.Fatalf("response error = %#v, want %q", response.Error, ErrorCodeMissingRequiredField.String()) + } + }) + + t.Run("approve plan runtime unavailable", func(t *testing.T) { + response := handleApprovePlanFrame(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "plan-1", + "revision": 1, + }, + }, nil) + if 
response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInternalError.String() { + t.Fatalf("response error = %#v, want %q", response.Error, ErrorCodeInternalError.String()) + } + }) + + t.Run("approve plan unsupported runtime port", func(t *testing.T) { + response := handleApprovePlanFrame(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "plan-1", + "revision": 1, + }, + }, runtimePortWithoutPlanApproval{RuntimePort: &bootstrapRuntimeStub{}}) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInternalError.String() { + t.Fatalf("response error = %#v, want %q", response.Error, ErrorCodeInternalError.String()) + } + }) + + t.Run("approve plan fills session from frame", func(t *testing.T) { + stub := &bootstrapRuntimeStub{ + approvePlanFn: func(_ context.Context, input ApprovePlanInput) (ApprovePlanResult, error) { + if input.SessionID != "session-from-frame" { + t.Fatalf("session_id = %q, want frame session", input.SessionID) + } + return ApprovePlanResult{PlanID: input.PlanID, Revision: input.Revision, Status: "approved"}, nil + }, + } + response := handleApprovePlanFrame(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + SessionID: " session-from-frame ", + Payload: map[string]any{ + "plan_id": "plan-1", + "revision": 1, + }, + }, stub) + if response.Type != FrameTypeAck { + t.Fatalf("response = %#v, want ack", response) + } + if response.SessionID != "session-from-frame" { + t.Fatalf("response session_id = %q, want frame session", response.SessionID) + } + }) + + t.Run("approve plan invalid revision", func(t *testing.T) { + response := 
handleApprovePlanFrame(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "plan-1", + "revision": 0, + }, + }, &bootstrapRuntimeStub{}) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInvalidAction.String() { + t.Fatalf("response error = %#v, want %q", response.Error, ErrorCodeInvalidAction.String()) + } + }) + + t.Run("approve plan success", func(t *testing.T) { + stub := &bootstrapRuntimeStub{ + approvePlanFn: func(ctx context.Context, input ApprovePlanInput) (ApprovePlanResult, error) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("approve plan should use timeout context") + } + if input.SubjectID == "" { + t.Fatal("subject id should be populated") + } + if input.SessionID != "session-1" || input.PlanID != "plan-1" || input.Revision != 2 { + t.Fatalf("approve input = %#v", input) + } + return ApprovePlanResult{PlanID: input.PlanID, Revision: input.Revision, Status: "approved"}, nil + }, + } + response := handleApprovePlanFrame(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + RequestID: "approve-ok", + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "plan-1", + "revision": 2, + }, + }, stub) + if response.Type != FrameTypeAck || response.Action != FrameActionApprovePlan { + t.Fatalf("response = %#v, want approve_plan ack", response) + } + payload, ok := response.Payload.(ApprovePlanResult) + if !ok { + t.Fatalf("payload type = %T, want ApprovePlanResult", response.Payload) + } + if payload.Status != "approved" || payload.PlanID != "plan-1" || payload.Revision != 2 { + t.Fatalf("payload = %#v", payload) + } + }) + + t.Run("approve plan runtime error", func(t *testing.T) { + stub := &bootstrapRuntimeStub{ + approvePlanFn: func(_ context.Context, _ 
ApprovePlanInput) (ApprovePlanResult, error) { + return ApprovePlanResult{}, errors.New("approve failed internals") + }, + } + response := handleApprovePlanFrame(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "plan-1", + "revision": 1, + }, + }, stub) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInternalError.String() { + t.Fatalf("response error = %#v, want %q", response.Error, ErrorCodeInternalError.String()) + } + if response.Error.Message != "approve_plan failed" { + t.Fatalf("response message = %q, want %q", response.Error.Message, "approve_plan failed") + } + }) } func TestHandleSessionSkillFramesBranches(t *testing.T) { @@ -3799,6 +4091,29 @@ func TestHandleRenameSessionFrameErrors(t *testing.T) { t.Fatalf("response error = %#v, want %q", response.Error, ErrorCodeInternalError.String()) } }) + + t.Run("approve plan invalid runtime action", func(t *testing.T) { + stub := &bootstrapRuntimeStub{ + approvePlanFn: func(_ context.Context, _ ApprovePlanInput) (ApprovePlanResult, error) { + return ApprovePlanResult{}, ErrRuntimeInvalidAction + }, + } + response := handleApprovePlanFrame(context.Background(), MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "plan-1", + "revision": 1, + }, + }, stub) + if response.Type != FrameTypeError { + t.Fatalf("response type = %q, want %q", response.Type, FrameTypeError) + } + if response.Error == nil || response.Error.Code != ErrorCodeInvalidAction.String() { + t.Fatalf("response error = %#v, want %q", response.Error, ErrorCodeInvalidAction.String()) + } + }) } func TestHandleListFilesFrameErrors(t *testing.T) { @@ -5028,6 +5343,12 @@ func (runtimeOnlyStub) GetRuntimeSnapshot(ctx 
context.Context, input GetRuntimeS func (runtimeOnlyStub) CreateSession(ctx context.Context, input CreateSessionInput) (string, error) { return "", nil } +func (runtimeOnlyStub) SaveSessionAsset(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} +func (runtimeOnlyStub) OpenSessionAsset(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} func (runtimeOnlyStub) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { return false, nil } diff --git a/internal/gateway/contracts.go b/internal/gateway/contracts.go index 9812b0f85..388a25402 100644 --- a/internal/gateway/contracts.go +++ b/internal/gateway/contracts.go @@ -2,6 +2,7 @@ package gateway import ( "context" + "io" "time" "neo-code/internal/config" @@ -49,6 +50,28 @@ type PermissionResolutionInput struct { Decision PermissionResolutionDecision `json:"decision"` } +// ApprovePlanInput is the input for approving the current plan's draft revision. +type ApprovePlanInput struct { + // SubjectID identifies the requesting principal. + SubjectID string `json:"subject_id,omitempty"` + // SessionID identifies the session that owns the plan. + SessionID string `json:"session_id"` + // PlanID identifies the target plan. + PlanID string `json:"plan_id"` + // Revision is the plan revision to approve. + Revision int `json:"revision"` +} + +// ApprovePlanResult is the stable result structure returned after a plan is approved. +type ApprovePlanResult struct { + // PlanID identifies the approved plan. + PlanID string `json:"plan_id"` + // Revision is the approved revision. + Revision int `json:"revision"` + // Status is the plan status after approval; currently always approved. + Status string `json:"status"` +} + // RunInput is the input the gateway passes to the downstream runtime port for a run action. type RunInput struct { // SubjectID identifies the requesting principal. @@ -205,6 +228,58 @@ type CreateSessionInput struct { SessionID string } +// SessionAssetMeta describes the session asset metadata visible to the Gateway. +type SessionAssetMeta struct { + // SessionID identifies the session that owns the asset. + SessionID string `json:"session_id"` + // AssetID identifies the asset. + AssetID string `json:"asset_id"` + // MimeType is the MIME type confirmed by the server. + MimeType string `json:"mime_type"` 
+ // Size is the asset's raw size in bytes. + Size int64 `json:"size"` +} + +// SaveSessionAssetInput is the downstream input for saving an asset uploaded from the browser. +type SaveSessionAssetInput struct { + // SubjectID identifies the requesting principal. + SubjectID string + // SessionID identifies the session that owns the asset. + SessionID string + // Reader provides the asset's binary content. + Reader io.Reader + // MimeType is the MIME type confirmed by server-side detection. + MimeType string +} + +// OpenSessionAssetInput is the downstream input for reading a session asset. +type OpenSessionAssetInput struct { + // SubjectID identifies the requesting principal. + SubjectID string + // SessionID identifies the session that owns the asset. + SessionID string + // AssetID identifies the asset. + AssetID string +} + +// DeleteSessionAssetInput is the downstream input for deleting a session asset. +type DeleteSessionAssetInput struct { + // SubjectID identifies the requesting principal. + SubjectID string + // SessionID identifies the session that owns the asset. + SessionID string + // AssetID identifies the asset. + AssetID string +} + +// OpenSessionAssetResult is the read result returned after opening a session asset. +type OpenSessionAssetResult struct { + // Reader is the asset content stream; the caller is responsible for closing it. + Reader io.ReadCloser + // Meta is the asset metadata. + Meta SessionAssetMeta +} + // DeleteSessionInput is the downstream input for the gateway.deleteSession action. type DeleteSessionInput struct { // SubjectID identifies the requesting principal. @@ -672,6 +747,8 @@ type SessionMessage struct { Role string `json:"role"` // Content is the message content. Content string `json:"content"` + // Parts holds the message's structured multimodal parts, rendered by clients that support images. + Parts []InputPart `json:"parts,omitempty"` // ToolCalls is the metadata for tool calls issued by the assistant. ToolCalls []ToolCall `json:"tool_calls,omitempty"` // ToolCallID is the call identifier associated with a tool message. @@ -680,6 +757,46 @@ type SessionMessage struct { IsError bool `json:"is_error,omitempty"` } +// PlanTodoItem is a legacy todo item retained in the plan body, used only for display and backward-compatible reads. +type PlanTodoItem struct { + ID string `json:"id"` + Content string `json:"content"` + Status string `json:"status,omitempty"` + Required bool `json:"required,omitempty"` + Artifacts []string `json:"artifacts,omitempty"` + FailureReason string `json:"failure_reason,omitempty"` + BlockedReason string `json:"blocked_reason,omitempty"` + Revision int64 `json:"revision,omitempty"` +} + +// PlanSpec is the public structure of the current full plan. +type PlanSpec struct { + Goal string `json:"goal"` + Steps []string 
`json:"steps,omitempty"` + Constraints []string `json:"constraints,omitempty"` + Todos []PlanTodoItem `json:"todos,omitempty"` + OpenQuestions []string `json:"open_questions,omitempty"` +} + +// PlanSummaryView is a compact summary of the full plan. +type PlanSummaryView struct { + Goal string `json:"goal"` + KeySteps []string `json:"key_steps,omitempty"` + Constraints []string `json:"constraints,omitempty"` + ActiveTodoIDs []string `json:"active_todo_ids,omitempty"` +} + +// PlanArtifact is a snapshot of the session's current plan. +type PlanArtifact struct { + ID string `json:"id"` + Revision int `json:"revision"` + Status string `json:"status"` + Spec PlanSpec `json:"spec"` + Summary PlanSummaryView `json:"summary"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + // Session is the session detail as seen by the gateway. type Session struct { // ID identifies the session. @@ -698,6 +815,8 @@ type Session struct { Model string `json:"model,omitempty"` // AgentMode is the session's current Agent working mode. AgentMode string `json:"agent_mode,omitempty"` + // CurrentPlan is a snapshot of the session's current structured plan. + CurrentPlan *PlanArtifact `json:"current_plan,omitempty"` // Messages is a snapshot of the session's messages. Messages []SessionMessage `json:"messages,omitempty"` } @@ -884,6 +1003,22 @@ type RuntimePort interface { CheckpointDiff(ctx context.Context, input CheckpointDiffInput) (CheckpointDiffResult, error) } +// SessionAssetPort defines the separate downstream port through which the Gateway HTTP asset endpoints access session assets. +type SessionAssetPort interface { + // SaveSessionAsset saves a session asset and returns its metadata. + SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) + // OpenSessionAsset opens a session asset so the HTTP read endpoint can return it. + OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) + // DeleteSessionAsset deletes a session asset that was uploaded but is no longer needed. + DeleteSessionAsset(ctx context.Context, input DeleteSessionAssetInput) error +} + +// PlanApprovalRuntimePort defines the optional downstream capability for approving plans. +type PlanApprovalRuntimePort interface { + // ApprovePlan advances the specified draft plan revision to approved. + ApprovePlan(ctx context.Context, input ApprovePlanInput) (ApprovePlanResult, 
error) +} + // ManagementRuntimePort defines the optional downstream port through which the frontend management plane accesses configuration capabilities. type ManagementRuntimePort interface { // ListProviders lists the manageable providers. diff --git a/internal/gateway/contracts_test.go b/internal/gateway/contracts_test.go index de1ef52cb..f13f1a99d 100644 --- a/internal/gateway/contracts_test.go +++ b/internal/gateway/contracts_test.go @@ -147,6 +147,18 @@ func (s *runtimePortCompileStub) CreateSession(_ context.Context, _ CreateSessio return "", nil } +func (s *runtimePortCompileStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} + +func (s *runtimePortCompileStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + +func (s *runtimePortCompileStub) DeleteSessionAsset(_ context.Context, _ DeleteSessionAssetInput) error { + return nil +} + func (s *runtimePortCompileStub) ListCheckpoints(_ context.Context, _ ListCheckpointsInput) ([]CheckpointEntry, error) { return nil, nil } @@ -164,5 +176,6 @@ func (s *runtimePortCompileStub) CheckpointDiff(_ context.Context, _ CheckpointD } var _ RuntimePort = (*runtimePortCompileStub)(nil) +var _ SessionAssetPort = (*runtimePortCompileStub)(nil) var _ TransportAdapter = (*Server)(nil) var _ TransportAdapter = (*NetworkServer)(nil) diff --git a/internal/gateway/errors.go b/internal/gateway/errors.go index 46c667c43..99417f82b 100644 --- a/internal/gateway/errors.go +++ b/internal/gateway/errors.go @@ -18,6 +18,8 @@ const ( ErrorCodeUnsupportedAction ErrorCode = "unsupported_action" // ErrorCodeInternalError indicates an internal gateway error. ErrorCodeInternalError ErrorCode = "internal_error" + // ErrorCodeMaxTurnExceeded indicates that the runtime stopped in a controlled way after reaching the maximum number of turns for a single run. + ErrorCodeMaxTurnExceeded ErrorCode = "max_turn_exceeded" // ErrorCodeTimeout indicates that a downstream call from the gateway timed out. ErrorCodeTimeout ErrorCode = "timeout" // ErrorCodeUnauthorized indicates that the request failed authentication. @@ -41,6 +43,7 @@ var stableErrorCodes = map[string]struct{}{ 
string(ErrorCodeMissingRequiredField): {}, string(ErrorCodeUnsupportedAction): {}, string(ErrorCodeInternalError): {}, + string(ErrorCodeMaxTurnExceeded): {}, string(ErrorCodeTimeout): {}, string(ErrorCodeUnauthorized): {}, string(ErrorCodeAccessDenied): {}, diff --git a/internal/gateway/errors_test.go b/internal/gateway/errors_test.go index b394ad7be..c42120aa1 100644 --- a/internal/gateway/errors_test.go +++ b/internal/gateway/errors_test.go @@ -10,9 +10,14 @@ func TestStableErrorCodes(t *testing.T) { ErrorCodeMissingRequiredField, ErrorCodeUnsupportedAction, ErrorCodeInternalError, + ErrorCodeMaxTurnExceeded, ErrorCodeTimeout, ErrorCodeUnauthorized, ErrorCodeAccessDenied, + ErrorCodeResourceNotFound, + ErrorCodeRunnerOffline, + ErrorCodeCapabilityDenied, + ErrorCodeToolExecutionFailed, } for _, code := range codes { diff --git a/internal/gateway/metrics.go b/internal/gateway/metrics.go index fabfdc7ce..ccc9acb34 100644 --- a/internal/gateway/metrics.go +++ b/internal/gateway/metrics.go @@ -32,6 +32,7 @@ var allowedRPCMethodMetricLabels = map[string]struct{}{ strings.ToLower(protocol.MethodGatewayListSessionTodos): {}, strings.ToLower(protocol.MethodGatewayGetRuntimeSnapshot): {}, strings.ToLower(protocol.MethodGatewayResolvePermission): {}, + strings.ToLower(protocol.MethodGatewayApprovePlan): {}, strings.ToLower(protocol.MethodGatewayDeleteSession): {}, strings.ToLower(protocol.MethodGatewayRenameSession): {}, strings.ToLower(protocol.MethodGatewayListFiles): {}, diff --git a/internal/gateway/multi_workspace_runtime.go b/internal/gateway/multi_workspace_runtime.go index f0674d862..f34b02fd4 100644 --- a/internal/gateway/multi_workspace_runtime.go +++ b/internal/gateway/multi_workspace_runtime.go @@ -345,6 +345,19 @@ func (m *MultiWorkspaceRuntime) ResolvePermission(ctx context.Context, input Per return port.ResolvePermission(ctx, input) } +// ApprovePlan routes the plan approval request to the optional plan approval capability of the current workspace's RuntimePort. +func (m *MultiWorkspaceRuntime) ApprovePlan(ctx context.Context, 
input ApprovePlanInput) (ApprovePlanResult, error) { + port, err := m.getPort(ctx) + if err != nil { + return ApprovePlanResult{}, err + } + approvalPort, ok := port.(PlanApprovalRuntimePort) + if !ok { + return ApprovePlanResult{}, fmt.Errorf("plan approval runtime port is unavailable") + } + return approvalPort.ApprovePlan(ctx, input) +} + func (m *MultiWorkspaceRuntime) ResolveUserQuestion(ctx context.Context, input UserQuestionAnswerInput) error { port, err := m.getPort(ctx) if err != nil { @@ -389,6 +402,43 @@ func (m *MultiWorkspaceRuntime) CreateSession(ctx context.Context, input CreateS return port.CreateSession(ctx, input) } +func (m *MultiWorkspaceRuntime) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + port, err := m.getPort(ctx) + if err != nil { + return SessionAssetMeta{}, err + } + assetPort, ok := port.(SessionAssetPort) + if !ok { + return SessionAssetMeta{}, ErrRuntimeUnavailable + } + return assetPort.SaveSessionAsset(ctx, input) +} + +func (m *MultiWorkspaceRuntime) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + port, err := m.getPort(ctx) + if err != nil { + return OpenSessionAssetResult{}, err + } + assetPort, ok := port.(SessionAssetPort) + if !ok { + return OpenSessionAssetResult{}, ErrRuntimeUnavailable + } + return assetPort.OpenSessionAsset(ctx, input) +} + +// DeleteSessionAsset selects the runtime bridge for the workspace in the request context and forwards the session asset deletion. +func (m *MultiWorkspaceRuntime) DeleteSessionAsset(ctx context.Context, input DeleteSessionAssetInput) error { + port, err := m.getPort(ctx) + if err != nil { + return err + } + assetPort, ok := port.(SessionAssetPort) + if !ok { + return ErrRuntimeUnavailable + } + return assetPort.DeleteSessionAsset(ctx, input) +} + func (m *MultiWorkspaceRuntime) DeleteSession(ctx context.Context, input DeleteSessionInput) (bool, error) { port, err := m.getPort(ctx) if err != nil { diff --git a/internal/gateway/multi_workspace_runtime_test.go 
b/internal/gateway/multi_workspace_runtime_test.go index fb3c65f32..13cbc3a83 100644 --- a/internal/gateway/multi_workspace_runtime_test.go +++ b/internal/gateway/multi_workspace_runtime_test.go @@ -5,6 +5,7 @@ import ( "errors" "os" "path/filepath" + "strings" "sync" "sync/atomic" "testing" @@ -24,8 +25,12 @@ type recordingPort struct { runCalls atomic.Int32 listSessionsCalls atomic.Int32 executeSysCalls atomic.Int32 + approvePlanCalls atomic.Int32 resolveUserCalls atomic.Int32 cancelCalls atomic.Int32 + saveAssetCalls atomic.Int32 + openAssetCalls atomic.Int32 + deleteAssetCalls atomic.Int32 closed atomic.Int32 closeOnce sync.Once @@ -89,6 +94,15 @@ func (p *recordingPort) ResolvePermission(_ context.Context, _ PermissionResolut return nil } +func (p *recordingPort) ApprovePlan(_ context.Context, input ApprovePlanInput) (ApprovePlanResult, error) { + p.approvePlanCalls.Add(1) + return ApprovePlanResult{ + PlanID: input.PlanID, + Revision: input.Revision, + Status: "approved", + }, nil +} + func (p *recordingPort) ResolveUserQuestion(_ context.Context, _ UserQuestionAnswerInput) error { p.resolveUserCalls.Add(1) return nil @@ -124,6 +138,21 @@ func (p *recordingPort) CreateSession(_ context.Context, _ CreateSessionInput) ( return p.id, nil } +func (p *recordingPort) SaveSessionAsset(_ context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + p.saveAssetCalls.Add(1) + return SessionAssetMeta{SessionID: input.SessionID, AssetID: p.id, MimeType: input.MimeType}, nil +} + +func (p *recordingPort) OpenSessionAsset(_ context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + p.openAssetCalls.Add(1) + return OpenSessionAssetResult{Meta: SessionAssetMeta{SessionID: input.SessionID, AssetID: input.AssetID}}, nil +} + +func (p *recordingPort) DeleteSessionAsset(_ context.Context, _ DeleteSessionAssetInput) error { + p.deleteAssetCalls.Add(1) + return nil +} + func (p *recordingPort) DeleteSession(_ context.Context, _ 
DeleteSessionInput) (bool, error) { return true, nil } @@ -508,6 +537,73 @@ func TestMultiWorkspaceRuntime_ResolveUserQuestionRoutesByWorkspace(t *testing.T } } +func TestMultiWorkspaceRuntime_ApprovePlanRoutesByWorkspace(t *testing.T) { + idx, alpha, beta := setupIndex(t) + builder := newTestBuilder() + mw := NewMultiWorkspaceRuntime(idx, alpha.Hash, builder.build) + t.Cleanup(func() { _ = mw.Close() }) + + result, err := mw.ApprovePlan(ctxWithHash(t, beta.Hash), ApprovePlanInput{ + SessionID: "session-1", + PlanID: "plan-1", + Revision: 2, + }) + if err != nil { + t.Fatalf("ApprovePlan: %v", err) + } + if result.PlanID != "plan-1" || result.Revision != 2 || result.Status != "approved" { + t.Fatalf("ApprovePlan result = %#v", result) + } + + betaPort := builder.portFor(beta.Path) + if betaPort == nil { + t.Fatalf("beta port should be built") + } + if got := betaPort.approvePlanCalls.Load(); got != 1 { + t.Fatalf("beta approve plan calls = %d, want 1", got) + } + if alphaPort := builder.portFor(alpha.Path); alphaPort != nil && alphaPort.approvePlanCalls.Load() != 0 { + t.Fatalf("alpha approve plan should not be called, got %d", alphaPort.approvePlanCalls.Load()) + } +} + +func TestMultiWorkspaceRuntime_ApprovePlanErrors(t *testing.T) { + t.Run("workspace not found", func(t *testing.T) { + idx, alpha, _ := setupIndex(t) + builder := newTestBuilder() + mw := NewMultiWorkspaceRuntime(idx, alpha.Hash, builder.build) + t.Cleanup(func() { _ = mw.Close() }) + + _, err := mw.ApprovePlan(ctxWithHash(t, "missing-workspace"), ApprovePlanInput{ + SessionID: "session-1", + PlanID: "plan-1", + Revision: 1, + }) + if !errors.Is(err, ErrRuntimeResourceNotFound) { + t.Fatalf("ApprovePlan error = %v, want ErrRuntimeResourceNotFound", err) + } + }) + + t.Run("runtime port does not support plan approval", func(t *testing.T) { + idx, alpha, beta := setupIndex(t) + builder := newTestBuilder() + mw := NewMultiWorkspaceRuntime(idx, alpha.Hash, builder.build) + 
mw.PreloadWorkspaceBundle(beta.Hash, runtimePortWithoutPlanApproval{ + RuntimePort: newRecordingPort("beta"), + }, func() error { return nil }) + t.Cleanup(func() { _ = mw.Close() }) + + _, err := mw.ApprovePlan(ctxWithHash(t, beta.Hash), ApprovePlanInput{ + SessionID: "session-1", + PlanID: "plan-1", + Revision: 1, + }) + if err == nil || !strings.Contains(err.Error(), "plan approval runtime port is unavailable") { + t.Fatalf("ApprovePlan error = %v, want unsupported plan approval", err) + } + }) +} + func TestMultiWorkspaceRuntime_CreatePersistsIndex(t *testing.T) { idx, alpha, _ := setupIndex(t) builder := newTestBuilder() @@ -705,6 +801,15 @@ func TestMultiWorkspaceRuntime_RoutingMatrix(t *testing.T) { if _, err := mw.ExecuteSystemTool(alphaCtx, ExecuteSystemToolInput{}); err != nil { t.Fatalf("ExecuteSystemTool alpha: %v", err) } + if _, err := mw.SaveSessionAsset(betaCtx, SaveSessionAssetInput{SessionID: "s-1", MimeType: "image/png"}); err != nil { + t.Fatalf("SaveSessionAsset beta: %v", err) + } + if _, err := mw.OpenSessionAsset(alphaCtx, OpenSessionAssetInput{SessionID: "s-1", AssetID: "asset-1"}); err != nil { + t.Fatalf("OpenSessionAsset alpha: %v", err) + } + if err := mw.DeleteSessionAsset(betaCtx, DeleteSessionAssetInput{SessionID: "s-1", AssetID: "asset-1"}); err != nil { + t.Fatalf("DeleteSessionAsset beta: %v", err) + } alphaPort := builder.portFor(alpha.Path) betaPort := builder.portFor(beta.Path) @@ -723,6 +828,15 @@ func TestMultiWorkspaceRuntime_RoutingMatrix(t *testing.T) { if got := alphaPort.executeSysCalls.Load(); got != 1 { t.Fatalf("alpha ExecuteSystemTool calls = %d, want 1", got) } + if got := betaPort.saveAssetCalls.Load(); got != 1 { + t.Fatalf("beta SaveSessionAsset calls = %d, want 1", got) + } + if got := alphaPort.openAssetCalls.Load(); got != 1 { + t.Fatalf("alpha OpenSessionAsset calls = %d, want 1", got) + } + if got := betaPort.deleteAssetCalls.Load(); got != 1 { + t.Fatalf("beta DeleteSessionAsset calls = %d, want 1", got) + 
} } func TestMultiWorkspaceRuntime_ListWorkspacesMatchesIndex(t *testing.T) { @@ -746,7 +860,40 @@ func TestMultiWorkspaceRuntime_ListWorkspacesMatchesIndex(t *testing.T) { // guard against future drift: MultiWorkspaceRuntime must implement RuntimePort and ManagementRuntimePort. var _ RuntimePort = (*MultiWorkspaceRuntime)(nil) +var _ SessionAssetPort = (*MultiWorkspaceRuntime)(nil) var _ ManagementRuntimePort = (*MultiWorkspaceRuntime)(nil) +var _ PlanApprovalRuntimePort = (*MultiWorkspaceRuntime)(nil) + +// recordingPortWithoutSessionAsset embeds the RuntimePort interface (rather than the concrete type *recordingPort), +// so that only RuntimePort methods are promoted and SessionAssetPort methods are not; +// it verifies how MultiWorkspaceRuntime degrades when the underlying runtime does not support assets. +type recordingPortWithoutSessionAsset struct{ RuntimePort } + +func TestMultiWorkspaceRuntime_DeleteSessionAssetUnsupportedRuntime(t *testing.T) { + idx, alpha, _ := setupIndex(t) + builder := newTestBuilder() + // Replace alpha's port with a version that does not support SessionAssetPort. + alphaPort := newRecordingPort("alpha-no-asset") + builder.ports[alpha.Path] = alphaPort + mw := NewMultiWorkspaceRuntime(idx, alpha.Hash, func(ctx context.Context, workdir string) (RuntimePort, func() error, error) { + port, cleanup, err := builder.build(ctx, workdir) + if err != nil { + return nil, nil, err + } + rp := port.(*recordingPort) + return &recordingPortWithoutSessionAsset{rp}, cleanup, nil + }) + t.Cleanup(func() { _ = mw.Close() }) + + alphaCtx := ctxWithHash(t, alpha.Hash) + err := mw.DeleteSessionAsset(alphaCtx, DeleteSessionAssetInput{SessionID: "s-1", AssetID: "a-1"}) + if err == nil { + t.Fatal("expected error when runtime does not implement SessionAssetPort") + } + if !errors.Is(err, ErrRuntimeUnavailable) { + t.Fatalf("error = %v, want ErrRuntimeUnavailable", err) + } +} // guard helper: ensure recordingPort builds correctly under sync access. 
func TestRecordingPort_Concurrent(t *testing.T) { diff --git a/internal/gateway/network_server.go b/internal/gateway/network_server.go index cc78a2bf8..cefc5a483 100644 --- a/internal/gateway/network_server.go +++ b/internal/gateway/network_server.go @@ -12,6 +12,7 @@ import ( "net" "net/http" "os" + "path" "strconv" "strings" "sync" @@ -21,6 +22,7 @@ import ( "golang.org/x/net/websocket" "neo-code/internal/gateway/protocol" + agentsession "neo-code/internal/session" ) const ( @@ -40,6 +42,8 @@ const ( DefaultNetworkMaxStreamConnections = 128 // DefaultWSUnauthenticatedTimeout defines the maximum wait time for an unauthenticated WS connection. DefaultWSUnauthenticatedTimeout = 3 * time.Second + // SessionAssetWorkspaceHeader defines the HTTP header that carries the current workspace when the Web client uploads or reads session assets. + SessionAssetWorkspaceHeader = "X-NeoCode-Workspace-Hash" ) var ( @@ -367,6 +371,12 @@ func (s *NetworkServer) buildHandler(runtimePort RuntimePort) http.Handler { mux.HandleFunc("/rpc", func(writer http.ResponseWriter, request *http.Request) { s.handleRPCRequest(writer, request, runtimePort) }) + mux.HandleFunc("/api/session-assets", func(writer http.ResponseWriter, request *http.Request) { + s.handleSessionAssetUpload(writer, request, runtimePort) + }) + mux.HandleFunc("/api/session-assets/", func(writer http.ResponseWriter, request *http.Request) { + s.handleSessionAssetRequest(writer, request, runtimePort) + }) mux.Handle("/ws", websocket.Server{ Handshake: func(_ *websocket.Config, request *http.Request) error { return s.validateWebSocketOrigin(request) @@ -387,6 +397,286 @@ func (s *NetworkServer) buildHandler(runtimePort RuntimePort) http.Handler { return mux } +// handleSessionAssetUpload receives an image uploaded from the browser and saves it as a session asset of the current session. +func (s *NetworkServer) handleSessionAssetUpload(writer http.ResponseWriter, request *http.Request, runtimePort RuntimePort) { + if request.Method != http.MethodPost { + http.Error(writer, "method not allowed", http.StatusMethodNotAllowed) + return + } + subjectID, ok := s.authenticatedHTTPSubjectID(request) + if !ok { + 
http.Error(writer, "unauthorized", http.StatusUnauthorized) + return + } + if !s.isHTTPControlPlaneMethodAllowed(sessionAssetUploadMethod) { + s.writeHTTPAccessDenied(writer, sessionAssetUploadMethod) + return + } + assetPort, ok := runtimePort.(SessionAssetPort) + if runtimePort == nil || !ok { + writeJSONResponse(writer, http.StatusServiceUnavailable, map[string]string{"error": "runtime unavailable"}) + return + } + + limit := agentsession.MaxSessionAssetBytes + request.Body = http.MaxBytesReader(writer, request.Body, limit+(1<<20)) + if err := request.ParseMultipartForm(limit + 4096); err != nil { + if strings.Contains(strings.ToLower(err.Error()), "too large") { + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": "asset is too large"}) + return + } + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "invalid multipart form"}) + return + } + + sessionID := strings.TrimSpace(request.FormValue("session_id")) + if sessionID == "" { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "session_id is required"}) + return + } + + file, _, err := request.FormFile("file") + if err != nil { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "file is required"}) + return + } + defer func() { + _ = file.Close() + }() + + payload, err := io.ReadAll(io.LimitReader(file, limit+1)) + if err != nil { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "read uploaded file failed"}) + return + } + if len(payload) == 0 { + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": "file is empty"}) + return + } + if int64(len(payload)) > limit { + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": "asset is too large"}) + return + } + + mimeType := detectAllowedUploadImageMime(payload) + if mimeType == "" { + writeJSONResponse(writer, http.StatusUnsupportedMediaType, 
map[string]string{"error": "unsupported image type"}) + return + } + + meta, err := assetPort.SaveSessionAsset(sessionAssetRequestContext(request), SaveSessionAssetInput{ + SubjectID: subjectID, + SessionID: sessionID, + Reader: bytes.NewReader(payload), + MimeType: mimeType, + }) + if err != nil { + writeSessionAssetUploadHTTPError(writer, err) + return + } + writeJSONResponse(writer, http.StatusOK, meta) +} + +// handleSessionAssetRequest dispatches session asset read or delete requests by HTTP method. +func (s *NetworkServer) handleSessionAssetRequest(writer http.ResponseWriter, request *http.Request, runtimePort RuntimePort) { + switch request.Method { + case http.MethodGet: + s.handleSessionAssetRead(writer, request, runtimePort) + case http.MethodDelete: + s.handleSessionAssetDelete(writer, request, runtimePort) + default: + http.Error(writer, "method not allowed", http.StatusMethodNotAllowed) + } +} + +// handleSessionAssetRead reads a session image asset so the Web client can show thumbnails in message history. +func (s *NetworkServer) handleSessionAssetRead(writer http.ResponseWriter, request *http.Request, runtimePort RuntimePort) { + subjectID, ok := s.authenticatedHTTPSubjectID(request) + if !ok { + http.Error(writer, "unauthorized", http.StatusUnauthorized) + return + } + if !s.isHTTPControlPlaneMethodAllowed(sessionAssetReadMethod) { + s.writeHTTPAccessDenied(writer, sessionAssetReadMethod) + return + } + assetPort, ok := runtimePort.(SessionAssetPort) + if runtimePort == nil || !ok { + writeJSONResponse(writer, http.StatusServiceUnavailable, map[string]string{"error": "runtime unavailable"}) + return + } + + sessionID, assetID, ok := parseSessionAssetPath(request.URL.Path) + if !ok { + http.NotFound(writer, request) + return + } + result, err := assetPort.OpenSessionAsset(sessionAssetRequestContext(request), OpenSessionAssetInput{ + SubjectID: subjectID, + SessionID: sessionID, + AssetID: assetID, + }) + if err != nil { + writeSessionAssetReadHTTPError(writer, err) + return + } + defer func() { + _ = result.Reader.Close() + }() + + 
writer.Header().Set("Content-Type", result.Meta.MimeType) + if result.Meta.Size > 0 { + writer.Header().Set("Content-Length", strconv.FormatInt(result.Meta.Size, 10)) + } + writer.Header().Set("Cache-Control", "private, max-age=300") + _, _ = io.Copy(writer, result.Reader) +} + +// handleSessionAssetDelete deletes a session image asset that the user uploaded but no longer needs. +func (s *NetworkServer) handleSessionAssetDelete(writer http.ResponseWriter, request *http.Request, runtimePort RuntimePort) { + subjectID, ok := s.authenticatedHTTPSubjectID(request) + if !ok { + http.Error(writer, "unauthorized", http.StatusUnauthorized) + return + } + if !s.isHTTPControlPlaneMethodAllowed(sessionAssetDeleteMethod) { + s.writeHTTPAccessDenied(writer, sessionAssetDeleteMethod) + return + } + assetPort, ok := runtimePort.(SessionAssetPort) + if runtimePort == nil || !ok { + writeJSONResponse(writer, http.StatusServiceUnavailable, map[string]string{"error": "runtime unavailable"}) + return + } + + sessionID, assetID, ok := parseSessionAssetPath(request.URL.Path) + if !ok { + http.NotFound(writer, request) + return + } + if err := assetPort.DeleteSessionAsset(sessionAssetRequestContext(request), DeleteSessionAssetInput{ + SubjectID: subjectID, + SessionID: sessionID, + AssetID: assetID, + }); err != nil { + writeSessionAssetReadHTTPError(writer, err) + return + } + writeJSONResponse(writer, http.StatusOK, map[string]bool{"deleted": true}) +} + +// sessionAssetRequestContext injects the workspace hash from the HTTP header into the request context so the multi-workspace Runtime can route the request. +func sessionAssetRequestContext(request *http.Request) context.Context { + if request == nil { + return context.Background() + } + workspaceHash := strings.TrimSpace(request.Header.Get(SessionAssetWorkspaceHeader)) + if workspaceHash == "" { + return request.Context() + } + state := NewConnectionWorkspaceState() + state.SetWorkspaceHash(workspaceHash) + return WithConnectionWorkspaceState(request.Context(), state) +} + +// authenticatedHTTPSubjectID validates the HTTP Bearer Token and returns the subject identifier. +func (s *NetworkServer) 
authenticatedHTTPSubjectID(request *http.Request) (string, bool) { + if s.authenticator == nil { + return "", false + } + token := extractBearerToken(request.Header.Get("Authorization")) + subjectID, ok := s.authenticator.ResolveSubjectID(token) + if !ok || strings.TrimSpace(subjectID) == "" { + return "", false + } + return strings.TrimSpace(subjectID), true +} + +// isHTTPControlPlaneMethodAllowed reuses the control-plane ACL for the HTTP source, covering HTTP endpoints that are not JSON-RPC. +func (s *NetworkServer) isHTTPControlPlaneMethodAllowed(method string) bool { + if s == nil || s.acl == nil { + return true + } + return s.acl.IsAllowed(RequestSourceHTTP, method) +} + +// writeHTTPAccessDenied records the ACL denial for an HTTP endpoint and returns a uniform 403 JSON response. +func (s *NetworkServer) writeHTTPAccessDenied(writer http.ResponseWriter, method string) { + if s != nil && s.metrics != nil { + s.metrics.IncACLDenied(string(RequestSourceHTTP), method) + } + writeJSONResponse(writer, http.StatusForbidden, map[string]string{"error": "access denied"}) +} + +// detectAllowedUploadImageMime confirms the uploaded image type from the file header; only PNG/JPEG/WebP are allowed. +func detectAllowedUploadImageMime(payload []byte) string { + if len(payload) == 0 { + return "" + } + probe := payload + if len(probe) > 512 { + probe = probe[:512] + } + mimeType := strings.ToLower(strings.TrimSpace(http.DetectContentType(probe))) + switch mimeType { + case "image/png", "image/jpeg", "image/webp": + return mimeType + default: + return "" + } +} + +// parseSessionAssetPath extracts the path parameters from /api/session-assets/{session_id}/{asset_id}. +func parseSessionAssetPath(rawPath string) (string, string, bool) { + cleanPath := path.Clean("/" + strings.TrimSpace(rawPath)) + const prefix = "/api/session-assets/" + if !strings.HasPrefix(cleanPath, prefix) { + return "", "", false + } + parts := strings.Split(strings.TrimPrefix(cleanPath, prefix), "/") + if len(parts) != 2 { + return "", "", false + } + sessionID := strings.TrimSpace(parts[0]) + assetID := strings.TrimSpace(parts[1]) + return sessionID, assetID, sessionID != "" && assetID 
!= "" +} + +// writeSessionAssetUploadHTTPError maps downstream errors from the upload stage to explicit HTTP statuses. +func writeSessionAssetUploadHTTPError(writer http.ResponseWriter, err error) { + writeSessionAssetHTTPError(writer, err, "session not found") +} + +// writeSessionAssetReadHTTPError maps downstream errors from the read stage to explicit HTTP statuses. +func writeSessionAssetReadHTTPError(writer http.ResponseWriter, err error) { + writeSessionAssetHTTPError(writer, err, "asset not found") +} + +// writeSessionAssetHTTPError maps downstream asset errors to explicit HTTP statuses. +func writeSessionAssetHTTPError(writer http.ResponseWriter, err error, notFoundMessage string) { + if err == nil { + writeJSONResponse(writer, http.StatusInternalServerError, map[string]string{"error": "unknown asset error"}) + return + } + message := strings.ToLower(err.Error()) + switch { + case strings.Contains(message, "workspace") && strings.Contains(message, "not found"): + writeJSONResponse(writer, http.StatusNotFound, map[string]string{"error": "workspace not found"}) + case errors.Is(err, ErrRuntimeUnavailable): + writeJSONResponse(writer, http.StatusServiceUnavailable, map[string]string{"error": "runtime unavailable"}) + case errors.Is(err, os.ErrNotExist) || errors.Is(err, ErrRuntimeResourceNotFound): + writeJSONResponse(writer, http.StatusNotFound, map[string]string{"error": notFoundMessage}) + case strings.Contains(message, "asset size exceeds"): + writeJSONResponse(writer, http.StatusRequestEntityTooLarge, map[string]string{"error": err.Error()}) + case strings.Contains(message, "unsupported") || strings.Contains(message, "not an image"): + writeJSONResponse(writer, http.StatusUnsupportedMediaType, map[string]string{"error": err.Error()}) + case strings.Contains(message, "access denied"): + writeJSONResponse(writer, http.StatusForbidden, map[string]string{"error": "access denied"}) + default: + writeJSONResponse(writer, http.StatusBadRequest, map[string]string{"error": err.Error()}) + } +} + // withCORS injects CORS headers for the network entry point, echoing the allowed value only for allowlisted Origins. // WebSocket upgrade requests are not subject to CORS and pass straight through to the 
Origin check in the WS handshake stage. func (s *NetworkServer) withCORS(next http.Handler) http.Handler { @@ -406,7 +696,7 @@ func (s *NetworkServer) withCORS(next http.Handler) http.Handler { } writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") - writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") + writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, "+SessionAssetWorkspaceHeader) if request.Method == http.MethodOptions { writer.WriteHeader(http.StatusNoContent) return diff --git a/internal/gateway/network_server_test.go b/internal/gateway/network_server_test.go index 96301525c..c1024eb51 100644 --- a/internal/gateway/network_server_test.go +++ b/internal/gateway/network_server_test.go @@ -2,10 +2,13 @@ package gateway import ( "bufio" + "bytes" "context" "encoding/json" + "fmt" "io" "log" + "mime/multipart" "net/http" "net/http/httptest" "strings" @@ -15,6 +18,7 @@ import ( "golang.org/x/net/websocket" "neo-code/internal/gateway/protocol" + agentsession "neo-code/internal/session" ) func TestResolveNetworkListenAddress(t *testing.T) { @@ -400,6 +404,468 @@ func TestNetworkServerRPCErrorBranches(t *testing.T) { }) } +func TestNetworkServerSessionAssetUploadAndRead(t *testing.T) { + payload := gatewayMinimalPNGBytes() + var capturedUpload SaveSessionAssetInput + var capturedDelete DeleteSessionAssetInput + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(_ context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + capturedUpload = input + got, err := io.ReadAll(input.Reader) + if err != nil { + t.Fatalf("read uploaded asset: %v", err) + } + if !bytes.Equal(got, payload) { + t.Fatalf("uploaded payload mismatch") + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(got)), + }, nil + }, + openAssetFn: func(_ context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if input.SubjectID != 
"local_admin" || input.SessionID != "session-1" || input.AssetID != "asset-1" { + t.Fatalf("open input = %+v, want subject/session/asset", input) + } + return OpenSessionAssetResult{ + Reader: io.NopCloser(bytes.NewReader(payload)), + Meta: SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: input.AssetID, + MimeType: "image/png", + Size: int64(len(payload)), + }, + }, nil + }, + deleteAssetFn: func(_ context.Context, input DeleteSessionAssetInput) error { + capturedDelete = input + return nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", payload) + uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusOK { + t.Fatalf("upload status = %d body=%s", uploadRecorder.Code, uploadRecorder.Body.String()) + } + var uploadResponse SessionAssetMeta + if err := json.Unmarshal(uploadRecorder.Body.Bytes(), &uploadResponse); err != nil { + t.Fatalf("decode upload response: %v", err) + } + if uploadResponse.AssetID != "asset-1" || uploadResponse.MimeType != "image/png" || uploadResponse.Size != int64(len(payload)) { + t.Fatalf("upload response = %+v", uploadResponse) + } + if capturedUpload.SubjectID != "local_admin" || capturedUpload.SessionID != "session-1" || capturedUpload.MimeType != "image/png" { + t.Fatalf("captured upload = %+v", capturedUpload) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusOK { + t.Fatalf("read status = %d body=%s", readRecorder.Code, readRecorder.Body.String()) + } + if got := 
readRecorder.Header().Get("Content-Type"); got != "image/png" { + t.Fatalf("read content-type = %q, want image/png", got) + } + if !bytes.Equal(readRecorder.Body.Bytes(), payload) { + t.Fatalf("read payload mismatch") + } + + deleteRequest := httptest.NewRequest(http.MethodDelete, "/api/session-assets/session-1/asset-1", nil) + deleteRequest.Header.Set("Authorization", "Bearer gateway-token") + deleteRecorder := httptest.NewRecorder() + handler.ServeHTTP(deleteRecorder, deleteRequest) + if deleteRecorder.Code != http.StatusOK { + t.Fatalf("delete status = %d body=%s", deleteRecorder.Code, deleteRecorder.Body.String()) + } + if capturedDelete.SubjectID != "local_admin" || + capturedDelete.SessionID != "session-1" || + capturedDelete.AssetID != "asset-1" { + t.Fatalf("captured delete = %+v", capturedDelete) + } +} + +func TestNetworkServerSessionAssetsRespectHTTPACL(t *testing.T) { + deniedACL := &ControlPlaneACL{ + mode: ACLModeStrict, + allow: map[RequestSource]map[string]struct{}{RequestSourceHTTP: {}}, + enabled: true, + } + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) { + t.Fatal("SaveSessionAsset should not be called when ACL denies upload") + return SessionAssetMeta{}, nil + }, + openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + t.Fatal("OpenSessionAsset should not be called when ACL denies read") + return OpenSessionAssetResult{}, nil + }, + deleteAssetFn: func(context.Context, DeleteSessionAssetInput) error { + t.Fatal("DeleteSessionAsset should not be called when ACL denies delete") + return nil + }, + } + server := &NetworkServer{ + authenticator: staticTokenAuthenticator{token: "gateway-token"}, + acl: deniedACL, + metrics: NewGatewayMetrics(), + } + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + 
uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusForbidden { + t.Fatalf("upload status = %d body=%s, want %d", uploadRecorder.Code, uploadRecorder.Body.String(), http.StatusForbidden) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusForbidden { + t.Fatalf("read status = %d body=%s, want %d", readRecorder.Code, readRecorder.Body.String(), http.StatusForbidden) + } + + deleteRequest := httptest.NewRequest(http.MethodDelete, "/api/session-assets/session-1/asset-1", nil) + deleteRequest.Header.Set("Authorization", "Bearer gateway-token") + deleteRecorder := httptest.NewRecorder() + handler.ServeHTTP(deleteRecorder, deleteRequest) + if deleteRecorder.Code != http.StatusForbidden { + t.Fatalf("delete status = %d body=%s, want %d", deleteRecorder.Code, deleteRecorder.Body.String(), http.StatusForbidden) + } +} + +func TestNetworkServerSessionAssetWorkspaceHeader(t *testing.T) { + payload := gatewayMinimalPNGBytes() + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if got := WorkspaceHashFromContext(ctx); got != "workspace-b" { + t.Fatalf("upload workspace hash = %q, want workspace-b", got) + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(payload)), + }, nil + }, + openAssetFn: func(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if got := WorkspaceHashFromContext(ctx); got != "workspace-b" { + t.Fatalf("read workspace hash = %q, want workspace-b", got) + } + return 
OpenSessionAssetResult{ + Reader: io.NopCloser(bytes.NewReader(payload)), + Meta: SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: input.AssetID, + MimeType: "image/png", + Size: int64(len(payload)), + }, + }, nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", payload) + uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRequest.Header.Set(SessionAssetWorkspaceHeader, "workspace-b") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusOK { + t.Fatalf("upload status = %d body=%s", uploadRecorder.Code, uploadRecorder.Body.String()) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRequest.Header.Set(SessionAssetWorkspaceHeader, "workspace-b") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusOK { + t.Fatalf("read status = %d body=%s", readRecorder.Code, readRecorder.Body.String()) + } +} + +func TestNetworkServerSessionAssetWorkspaceHeaderEmptyFallback(t *testing.T) { + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if got := WorkspaceHashFromContext(ctx); got != "" { + t.Fatalf("workspace hash = %q, want empty fallback", got) + } + return SessionAssetMeta{ + SessionID: input.SessionID, + AssetID: "asset-1", + MimeType: input.MimeType, + Size: int64(len(gatewayMinimalPNGBytes())), + }, nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + request := newSessionAssetUploadRequest(t, "session-1", 
"a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", recorder.Code, recorder.Body.String()) + } +} + +func TestNetworkServerSessionAssetUploadErrors(t *testing.T) { + runtimePort := &runtimePortEventStub{} + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.withCORS(server.buildHandler(runtimePort)) + + t.Run("unauthorized", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusUnauthorized { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnauthorized) + } + }) + + t.Run("forbidden origin", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + request.Header.Set("Origin", "http://evil.example") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusForbidden) + } + }) + + t.Run("non image", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "bad.txt", []byte("not an image")) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusUnsupportedMediaType { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusUnsupportedMediaType) + } + }) + + t.Run("empty file", func(t *testing.T) { + request := newSessionAssetUploadRequest(t, "session-1", "empty.png", nil) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := 
httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusBadRequest) + } + }) + + t.Run("oversized file", func(t *testing.T) { + request := newSessionAssetUploadRequest( + t, + "session-1", + "huge.png", + bytes.Repeat([]byte{0}, int(agentsession.MaxSessionAssetBytes)+1), + ) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusRequestEntityTooLarge { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusRequestEntityTooLarge) + } + }) + + t.Run("workspace not found", func(t *testing.T) { + runtimePort := &runtimePortEventStub{ + saveAssetFn: func(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, fmt.Errorf("%w: workspace missing not found", ErrRuntimeResourceNotFound) + }, + } + handler := server.withCORS(server.buildHandler(runtimePort)) + request := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + request.Header.Set("Authorization", "Bearer gateway-token") + request.Header.Set(SessionAssetWorkspaceHeader, "missing") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusNotFound) + } + if !strings.Contains(recorder.Body.String(), "workspace not found") { + t.Fatalf("body = %s, want workspace not found", recorder.Body.String()) + } + }) +} + +func TestNetworkServerSessionAssetReadNotFound(t *testing.T) { + runtimePort := &runtimePortEventStub{ + openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, ErrRuntimeResourceNotFound + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := 
server.buildHandler(runtimePort) + + request := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/missing", nil) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d", recorder.Code, http.StatusNotFound) + } +} + +func TestNetworkServerSessionAssetDeleteMissingIsIdempotent(t *testing.T) { + called := false + runtimePort := &runtimePortEventStub{ + deleteAssetFn: func(_ context.Context, input DeleteSessionAssetInput) error { + called = true + if input.SubjectID != "local_admin" || input.SessionID != "session-1" || input.AssetID != "missing" { + t.Fatalf("delete input = %+v, want subject/session/missing", input) + } + return nil + }, + } + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + request := httptest.NewRequest(http.MethodDelete, "/api/session-assets/session-1/missing", nil) + request.Header.Set("Authorization", "Bearer gateway-token") + recorder := httptest.NewRecorder() + handler.ServeHTTP(recorder, request) + if recorder.Code != http.StatusOK { + t.Fatalf("status = %d body=%s, want %d", recorder.Code, recorder.Body.String(), http.StatusOK) + } + if !called { + t.Fatal("DeleteSessionAsset was not called") + } +} + +// TestNetworkServerSessionAssetACLIndependent verifies that the ACL checks for GET and DELETE are independent: +// when only read is allowed, GET passes and DELETE is denied; when only delete is allowed, DELETE passes and GET is denied. +func TestNetworkServerSessionAssetACLIndependent(t *testing.T) { + t.Run("read allowed delete denied", func(t *testing.T) { + readOnlyACL := &ControlPlaneACL{ + mode: ACLModeStrict, + allow: map[RequestSource]map[string]struct{}{RequestSourceHTTP: normalizedMethodSet(sessionAssetReadMethod)}, + enabled: true, + } + runtimePort := &runtimePortEventStub{ + openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) 
{ + return OpenSessionAssetResult{ + Reader: io.NopCloser(bytes.NewReader(gatewayMinimalPNGBytes())), + Meta: SessionAssetMeta{SessionID: "session-1", AssetID: "asset-1", MimeType: "image/png"}, + }, nil + }, + deleteAssetFn: func(context.Context, DeleteSessionAssetInput) error { + t.Fatal("DeleteSessionAsset should not be called when ACL denies delete") + return nil + }, + } + server := &NetworkServer{ + authenticator: staticTokenAuthenticator{token: "gateway-token"}, + acl: readOnlyACL, + metrics: NewGatewayMetrics(), + } + handler := server.buildHandler(runtimePort) + + // GET should succeed (read allowed) + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusOK { + t.Fatalf("read status = %d, want %d", readRecorder.Code, http.StatusOK) + } + + // DELETE should be forbidden (delete denied) + deleteRequest := httptest.NewRequest(http.MethodDelete, "/api/session-assets/session-1/asset-1", nil) + deleteRequest.Header.Set("Authorization", "Bearer gateway-token") + deleteRecorder := httptest.NewRecorder() + handler.ServeHTTP(deleteRecorder, deleteRequest) + if deleteRecorder.Code != http.StatusForbidden { + t.Fatalf("delete status = %d, want %d", deleteRecorder.Code, http.StatusForbidden) + } + }) + + t.Run("delete allowed read denied", func(t *testing.T) { + deleteOnlyACL := &ControlPlaneACL{ + mode: ACLModeStrict, + allow: map[RequestSource]map[string]struct{}{RequestSourceHTTP: normalizedMethodSet(sessionAssetDeleteMethod)}, + enabled: true, + } + runtimePort := &runtimePortEventStub{ + openAssetFn: func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) { + t.Fatal("OpenSessionAsset should not be called when ACL denies read") + return OpenSessionAssetResult{}, nil + }, + deleteAssetFn: func(context.Context, 
DeleteSessionAssetInput) error { + return nil + }, + } + server := &NetworkServer{ + authenticator: staticTokenAuthenticator{token: "gateway-token"}, + acl: deleteOnlyACL, + metrics: NewGatewayMetrics(), + } + handler := server.buildHandler(runtimePort) + + // DELETE should succeed (delete allowed) + deleteRequest := httptest.NewRequest(http.MethodDelete, "/api/session-assets/session-1/asset-1", nil) + deleteRequest.Header.Set("Authorization", "Bearer gateway-token") + deleteRecorder := httptest.NewRecorder() + handler.ServeHTTP(deleteRecorder, deleteRequest) + if deleteRecorder.Code != http.StatusOK { + t.Fatalf("delete status = %d, want %d", deleteRecorder.Code, http.StatusOK) + } + + // GET should be forbidden (read denied) + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusForbidden { + t.Fatalf("read status = %d, want %d", readRecorder.Code, http.StatusForbidden) + } + }) +} + +func TestNetworkServerSessionAssetsRequireAssetPort(t *testing.T) { + runtimePort := &runtimePortWithoutSessionAsset{RuntimePort: &runtimePortEventStub{}} + server := &NetworkServer{authenticator: staticTokenAuthenticator{token: "gateway-token"}} + handler := server.buildHandler(runtimePort) + + uploadRequest := newSessionAssetUploadRequest(t, "session-1", "a.png", gatewayMinimalPNGBytes()) + uploadRequest.Header.Set("Authorization", "Bearer gateway-token") + uploadRecorder := httptest.NewRecorder() + handler.ServeHTTP(uploadRecorder, uploadRequest) + if uploadRecorder.Code != http.StatusServiceUnavailable { + t.Fatalf("upload status = %d body=%s, want %d", uploadRecorder.Code, uploadRecorder.Body.String(), http.StatusServiceUnavailable) + } + + readRequest := httptest.NewRequest(http.MethodGet, "/api/session-assets/session-1/asset-1", nil) + 
readRequest.Header.Set("Authorization", "Bearer gateway-token") + readRecorder := httptest.NewRecorder() + handler.ServeHTTP(readRecorder, readRequest) + if readRecorder.Code != http.StatusServiceUnavailable { + t.Fatalf("read status = %d body=%s, want %d", readRecorder.Code, readRecorder.Body.String(), http.StatusServiceUnavailable) + } + + deleteRequest := httptest.NewRequest(http.MethodDelete, "/api/session-assets/session-1/asset-1", nil) + deleteRequest.Header.Set("Authorization", "Bearer gateway-token") + deleteRecorder := httptest.NewRecorder() + handler.ServeHTTP(deleteRecorder, deleteRequest) + if deleteRecorder.Code != http.StatusServiceUnavailable { + t.Fatalf("delete status = %d body=%s, want %d", deleteRecorder.Code, deleteRecorder.Body.String(), http.StatusServiceUnavailable) + } +} + func TestNetworkServerWebSocketAndSSEPing(t *testing.T) { server := newTestNetworkServer(t, NetworkServerOptions{}) testContext, cancel := context.WithCancel(context.Background()) @@ -1322,6 +1788,45 @@ type noFlushResponseWriter struct { body strings.Builder } +func newSessionAssetUploadRequest(t *testing.T, sessionID, fileName string, payload []byte) *http.Request { + t.Helper() + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + if sessionID != "" { + if err := writer.WriteField("session_id", sessionID); err != nil { + t.Fatalf("write session_id field: %v", err) + } + } + part, err := writer.CreateFormFile("file", fileName) + if err != nil { + t.Fatalf("create file part: %v", err) + } + if _, err := part.Write(payload); err != nil { + t.Fatalf("write file part: %v", err) + } + if err := writer.Close(); err != nil { + t.Fatalf("close multipart writer: %v", err) + } + request := httptest.NewRequest(http.MethodPost, "/api/session-assets", &body) + request.Header.Set("Content-Type", writer.FormDataContentType()) + return request +} + +func gatewayMinimalPNGBytes() []byte { + return []byte{ + 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, + 0x00, 0x00, 0x00, 
0x0d, 0x49, 0x48, 0x44, 0x52, + 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, + 0x08, 0x06, 0x00, 0x00, 0x00, 0x1f, 0x15, 0xc4, + 0x89, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x44, 0x41, + 0x54, 0x78, 0x9c, 0x63, 0xf8, 0xcf, 0xc0, 0x00, + 0x00, 0x03, 0x01, 0x01, 0x00, 0xc9, 0xfe, 0x92, + 0xef, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, + 0x44, 0xae, 0x42, 0x60, 0x82, + } +} + type staticTokenAuthenticator struct { token string } diff --git a/internal/gateway/protocol/jsonrpc.go b/internal/gateway/protocol/jsonrpc.go index 41e1193e5..71ab1d6ed 100644 --- a/internal/gateway/protocol/jsonrpc.go +++ b/internal/gateway/protocol/jsonrpc.go @@ -63,6 +63,8 @@ const ( MethodGatewayCheckpointDiff = "checkpoint.diff" // MethodGatewayResolvePermission 表示提交权限审批决策。 MethodGatewayResolvePermission = "gateway.resolvePermission" + // MethodGatewayApprovePlan 表示批准当前 draft 计划 revision。 + MethodGatewayApprovePlan = "gateway.approvePlan" // MethodGatewayUserQuestionAnswer 表示提交 ask_user 回答。 MethodGatewayUserQuestionAnswer = "gateway.userQuestionAnswer" // MethodGatewayDeleteSession 表示删除/归档会话。 @@ -138,6 +140,8 @@ const ( GatewayCodeUnsupportedAction = "unsupported_action" // GatewayCodeInternalError 表示网关内部错误。 GatewayCodeInternalError = "internal_error" + // GatewayCodeMaxTurnExceeded 表示 runtime 达到单次运行最大轮数后受控停止。 + GatewayCodeMaxTurnExceeded = "max_turn_exceeded" // GatewayCodeTimeout 表示网关处理请求时发生超时。 GatewayCodeTimeout = "timeout" // GatewayCodeUnsafePath 表示路径存在安全风险。 @@ -246,6 +250,7 @@ const ( // RunInputMedia 用于承载 gateway.run 中图片分片的媒体元数据。 type RunInputMedia struct { URI string `json:"uri"` + AssetID string `json:"asset_id,omitempty"` MimeType string `json:"mime_type"` FileName string `json:"file_name,omitempty"` } @@ -366,6 +371,13 @@ type ResolvePermissionParams struct { Decision string `json:"decision"` } +// ApprovePlanParams 表示 gateway.approvePlan 参数。 +type ApprovePlanParams struct { + SessionID string `json:"session_id"` + PlanID string `json:"plan_id"` + Revision int `json:"revision"` +} + 
// UserQuestionAnswerParams holds the parameters for gateway.userQuestionAnswer. type UserQuestionAnswerParams struct { RequestID string `json:"request_id"` @@ -779,6 +791,15 @@ func NormalizeJSONRPCRequest(request JSONRPCRequest) (NormalizedRequest, *JSONRP normalized.Action = "resolve_permission" normalized.Payload = params return normalized, nil + case MethodGatewayApprovePlan: + params, parseErr := decodeApprovePlanParams(request.Params) + if parseErr != nil { + return normalized, parseErr + } + normalized.Action = "approve_plan" + normalized.SessionID = strings.TrimSpace(params.SessionID) + normalized.Payload = params + return normalized, nil case MethodGatewayUserQuestionAnswer: params, parseErr := decodeUserQuestionAnswerParams(request.Params) if parseErr != nil { @@ -1183,7 +1204,8 @@ func MapGatewayCodeToJSONRPCCode(gatewayCode string) int { GatewayCodeUnsafePath, GatewayCodeUnauthorized, GatewayCodeAccessDenied, - GatewayCodeResourceNotFound: + GatewayCodeResourceNotFound, + GatewayCodeMaxTurnExceeded: return JSONRPCCodeInvalidParams case GatewayCodeInternalError: return JSONRPCCodeInternalError @@ -1381,6 +1403,7 @@ func decodeRunParams(raw json.RawMessage) (RunParams, *JSONRPCError) { p.InputParts[i].Text = strings.TrimSpace(p.InputParts[i].Text) if m := p.InputParts[i].Media; m != nil { m.URI = strings.TrimSpace(m.URI) + m.AssetID = strings.TrimSpace(m.AssetID) m.MimeType = strings.TrimSpace(m.MimeType) m.FileName = strings.TrimSpace(m.FileName) } @@ -1514,6 +1537,24 @@ func decodeResolvePermissionParams(raw json.RawMessage) (ResolvePermissionParams }) } +// decodeApprovePlanParams deserializes and validates the params of gateway.approvePlan. +func decodeApprovePlanParams(raw json.RawMessage) (ApprovePlanParams, *JSONRPCError) { + return decodeParams(raw, "gateway.approvePlan", func(p *ApprovePlanParams) *JSONRPCError { + p.SessionID = strings.TrimSpace(p.SessionID) + p.PlanID = strings.TrimSpace(p.PlanID) + if p.SessionID == "" { + return NewJSONRPCError(JSONRPCCodeInvalidParams, 
"missing required field: params.session_id", GatewayCodeMissingRequiredField) + } + if p.PlanID == "" { + return NewJSONRPCError(JSONRPCCodeInvalidParams, "missing required field: params.plan_id", GatewayCodeMissingRequiredField) + } + if p.Revision <= 0 { + return NewJSONRPCError(JSONRPCCodeInvalidParams, "invalid field: params.revision", GatewayCodeInvalidAction) + } + return nil + }) +} + // decodeUserQuestionAnswerParams 对 gateway.userQuestionAnswer 的 params 执行反序列化与字段校验。 func decodeUserQuestionAnswerParams(raw json.RawMessage) (UserQuestionAnswerParams, *JSONRPCError) { return decodeParams(raw, "gateway.userQuestionAnswer", func(p *UserQuestionAnswerParams) *JSONRPCError { diff --git a/internal/gateway/protocol/jsonrpc_test.go b/internal/gateway/protocol/jsonrpc_test.go index 2d92c2776..c64211e6c 100644 --- a/internal/gateway/protocol/jsonrpc_test.go +++ b/internal/gateway/protocol/jsonrpc_test.go @@ -316,6 +316,71 @@ func TestNormalizeJSONRPCRequestCheckpointMethods(t *testing.T) { }) } +func TestNormalizeJSONRPCRequestApprovePlan(t *testing.T) { + t.Run("success", func(t *testing.T) { + normalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"approve-plan-1"`), + Method: MethodGatewayApprovePlan, + Params: json.RawMessage(`{"session_id":" session-1 ","plan_id":" plan-1 ","revision":2}`), + }) + if rpcErr != nil { + t.Fatalf("normalize approvePlan request: %v", rpcErr) + } + if normalized.Action != "approve_plan" { + t.Fatalf("action = %q, want %q", normalized.Action, "approve_plan") + } + if normalized.SessionID != "session-1" { + t.Fatalf("session_id = %q, want %q", normalized.SessionID, "session-1") + } + params, ok := normalized.Payload.(ApprovePlanParams) + if !ok { + t.Fatalf("payload type = %T, want ApprovePlanParams", normalized.Payload) + } + if params.PlanID != "plan-1" || params.Revision != 2 { + t.Fatalf("params = %#v, want plan-1 revision 2", params) + } + }) + + tests := []struct { + name 
string + params string + wantGatewayCode string + }{ + { + name: "missing session", + params: `{"session_id":" ","plan_id":"plan-1","revision":1}`, + wantGatewayCode: GatewayCodeMissingRequiredField, + }, + { + name: "missing plan", + params: `{"session_id":"session-1","plan_id":" ","revision":1}`, + wantGatewayCode: GatewayCodeMissingRequiredField, + }, + { + name: "invalid revision", + params: `{"session_id":"session-1","plan_id":"plan-1","revision":0}`, + wantGatewayCode: GatewayCodeInvalidAction, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ + JSONRPC: JSONRPCVersion, + ID: json.RawMessage(`"approve-plan-invalid"`), + Method: MethodGatewayApprovePlan, + Params: json.RawMessage(tt.params), + }) + if rpcErr == nil || rpcErr.Code != JSONRPCCodeInvalidParams { + t.Fatalf("expected invalid params error, got %#v", rpcErr) + } + if gatewayCode := GatewayCodeFromJSONRPCError(rpcErr); gatewayCode != tt.wantGatewayCode { + t.Fatalf("gateway_code = %q, want %q", gatewayCode, tt.wantGatewayCode) + } + }) + } +} + func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { runRequest := JSONRPCRequest{ JSONRPC: JSONRPCVersion, @@ -328,7 +393,8 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { "workdir":" /tmp/work ", "input_parts":[ {"type":" TEXT ","text":" world "}, - {"type":" image ","media":{"uri":" /tmp/a.png ","mime_type":" image/png ","file_name":" a.png "}} + {"type":" image ","media":{"uri":" /tmp/a.png ","mime_type":" image/png ","file_name":" a.png "}}, + {"type":" image ","media":{"asset_id":" asset-1 ","mime_type":" image/webp "}} ] }`), } @@ -349,8 +415,8 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { if runParams.InputText != "hello" { t.Fatalf("run input_text = %q, want %q", runParams.InputText, "hello") } - if len(runParams.InputParts) != 2 { - t.Fatalf("run input_parts len = %d, want 2", len(runParams.InputParts)) + if 
len(runParams.InputParts) != 3 { + t.Fatalf("run input_parts len = %d, want 3", len(runParams.InputParts)) } if runParams.InputParts[0].Type != "text" || runParams.InputParts[0].Text != "world" { t.Fatalf("run text part = %#v, want normalized text part", runParams.InputParts[0]) @@ -361,6 +427,12 @@ func TestNormalizeJSONRPCRequestRuntimeMethods(t *testing.T) { if runParams.InputParts[1].Media.MimeType != "image/png" || runParams.InputParts[1].Media.FileName != "a.png" { t.Fatalf("run image media = %#v, want trimmed mime/file_name", runParams.InputParts[1].Media) } + if runParams.InputParts[2].Type != "image" || + runParams.InputParts[2].Media == nil || + runParams.InputParts[2].Media.AssetID != "asset-1" || + runParams.InputParts[2].Media.MimeType != "image/webp" { + t.Fatalf("run image asset media = %#v, want trimmed asset_id/mime", runParams.InputParts[2]) + } compactNormalized, rpcErr := NormalizeJSONRPCRequest(JSONRPCRequest{ JSONRPC: JSONRPCVersion, diff --git a/internal/gateway/registry.go b/internal/gateway/registry.go index 853a665f3..95e129dee 100644 --- a/internal/gateway/registry.go +++ b/internal/gateway/registry.go @@ -59,6 +59,7 @@ func (r *ActionRegistry) initCore() { r.core[FrameActionListSessionTodos] = handleListSessionTodosFrame r.core[FrameActionGetRuntimeSnapshot] = handleGetRuntimeSnapshotFrame r.core[FrameActionResolvePermission] = handleResolvePermissionFrame + r.core[FrameActionApprovePlan] = handleApprovePlanFrame r.core[FrameActionUserQuestionAnswer] = handleUserQuestionAnswerFrame r.core[FrameActionDeleteSession] = handleDeleteSessionFrame r.core[FrameActionRenameSession] = handleRenameSessionFrame diff --git a/internal/gateway/rpc_dispatch.go b/internal/gateway/rpc_dispatch.go index 0360b000d..33639ff66 100644 --- a/internal/gateway/rpc_dispatch.go +++ b/internal/gateway/rpc_dispatch.go @@ -344,6 +344,7 @@ func convertProtocolRunInputParts(parts []protocol.RunInputPart) []InputPart { if part.Media != nil { convertedPart.Media = &Media{ 
URI: strings.TrimSpace(part.Media.URI), + AssetID: strings.TrimSpace(part.Media.AssetID), MimeType: strings.TrimSpace(part.Media.MimeType), FileName: strings.TrimSpace(part.Media.FileName), } @@ -365,6 +366,7 @@ func requiresSession(action FrameAction) bool { FrameActionActivateSessionSkill, FrameActionDeactivateSessionSkill, FrameActionListSessionSkills, + FrameActionApprovePlan, FrameActionDeleteSession, FrameActionRenameSession, FrameActionSetSessionModel, diff --git a/internal/gateway/rpc_dispatch_test.go b/internal/gateway/rpc_dispatch_test.go index 1f2e8e57f..737851253 100644 --- a/internal/gateway/rpc_dispatch_test.go +++ b/internal/gateway/rpc_dispatch_test.go @@ -25,6 +25,8 @@ type rpcRunCaptureRuntimeStub struct { deactivateSkillFn func(ctx context.Context, input SessionSkillMutationInput) error listSessionSkillsFn func(ctx context.Context, input ListSessionSkillsInput) ([]SessionSkillState, error) listAvailableFn func(ctx context.Context, input ListAvailableSkillsInput) ([]AvailableSkillState, error) + approvePlanInput ApprovePlanInput + approvePlanFn func(ctx context.Context, input ApprovePlanInput) (ApprovePlanResult, error) loadSessionFn func(ctx context.Context, input LoadSessionInput) (Session, error) listProvidersFn func(ctx context.Context, input ListProvidersInput) ([]ProviderOption, error) createProviderFn func(ctx context.Context, input CreateProviderInput) (ProviderSelectionResult, error) @@ -115,6 +117,17 @@ func (s *rpcRunCaptureRuntimeStub) ResolvePermission(_ context.Context, _ Permis return nil } +func (s *rpcRunCaptureRuntimeStub) ApprovePlan( + ctx context.Context, + input ApprovePlanInput, +) (ApprovePlanResult, error) { + s.approvePlanInput = input + if s.approvePlanFn != nil { + return s.approvePlanFn(ctx, input) + } + return ApprovePlanResult{}, nil +} + func (s *rpcRunCaptureRuntimeStub) ResolveUserQuestion(_ context.Context, _ UserQuestionAnswerInput) error { return nil } @@ -221,6 +234,14 @@ func (s *rpcRunCaptureRuntimeStub) 
CreateSession(ctx context.Context, input Crea return s.createSessionID, nil } +func (s *rpcRunCaptureRuntimeStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} + +func (s *rpcRunCaptureRuntimeStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} + func (s *rpcRunCaptureRuntimeStub) ListSessionTodos(_ context.Context, _ ListSessionTodosInput) (TodoSnapshot, error) { return TodoSnapshot{}, nil } @@ -608,6 +629,83 @@ func TestDispatchRPCRequestResolvePermissionDoesNotRequireSession(t *testing.T) } } +func TestDispatchRPCRequestApprovePlanAllowedForAuthenticatedWebSocket(t *testing.T) { + authState := NewConnectionAuthState() + authState.MarkAuthenticated("local_admin") + ctx := WithRequestSource(context.Background(), RequestSourceWS) + ctx = WithRequestACL(ctx, NewStrictControlPlaneACL()) + ctx = WithConnectionAuthState(ctx, authState) + + runtimeStub := &rpcRunCaptureRuntimeStub{ + approvePlanFn: func(_ context.Context, input ApprovePlanInput) (ApprovePlanResult, error) { + if input.SubjectID != "local_admin" { + t.Fatalf("subject_id = %q, want %q", input.SubjectID, "local_admin") + } + if input.SessionID != "session-1" || input.PlanID != "plan-1" || input.Revision != 2 { + t.Fatalf("approve input = %#v", input) + } + return ApprovePlanResult{ + PlanID: input.PlanID, + Revision: input.Revision, + Status: "approved", + }, nil + }, + } + + response := dispatchRPCRequest(ctx, protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-approve-plan"`), + Method: protocol.MethodGatewayApprovePlan, + Params: json.RawMessage(`{"session_id":"session-1","plan_id":"plan-1","revision":2}`), + }, runtimeStub) + if response.Error != nil { + t.Fatalf("approvePlan should pass strict WS ACL, got error: %+v", response.Error) + } + + frame, err := decodeJSONRPCResultFrame(response) + if err 
!= nil { + t.Fatalf("decode approvePlan result frame: %v", err) + } + if frame.Action != FrameActionApprovePlan { + t.Fatalf("response action = %q, want %q", frame.Action, FrameActionApprovePlan) + } + payload, ok := frame.Payload.(map[string]any) + if !ok { + t.Fatalf("payload type = %T, want map[string]any", frame.Payload) + } + if payload["plan_id"] != "plan-1" || payload["status"] != "approved" || payload["revision"] != float64(2) { + t.Fatalf("payload = %#v, want approved plan result", payload) + } + if runtimeStub.approvePlanInput.PlanID != "plan-1" { + t.Fatalf("approvePlan was not called, captured input = %#v", runtimeStub.approvePlanInput) + } +} + +func TestDispatchRPCRequestApprovePlanInvalidRuntimeAction(t *testing.T) { + authState := NewConnectionAuthState() + authState.MarkAuthenticated("local_admin") + ctx := WithRequestSource(context.Background(), RequestSourceWS) + ctx = WithRequestACL(ctx, NewStrictControlPlaneACL()) + ctx = WithConnectionAuthState(ctx, authState) + + response := dispatchRPCRequest(ctx, protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-approve-plan-invalid"`), + Method: protocol.MethodGatewayApprovePlan, + Params: json.RawMessage(`{"session_id":"session-1","plan_id":"plan-1","revision":2}`), + }, &rpcRunCaptureRuntimeStub{ + approvePlanFn: func(_ context.Context, _ ApprovePlanInput) (ApprovePlanResult, error) { + return ApprovePlanResult{}, ErrRuntimeInvalidAction + }, + }) + if response.Error == nil { + t.Fatal("approvePlan invalid action should return JSON-RPC error") + } + if response.Error.Data == nil || response.Error.Data.GatewayCode != ErrorCodeInvalidAction.String() { + t.Fatalf("approvePlan error = %#v, want invalid_action", response.Error) + } +} + func TestDispatchRPCRequestExecuteSystemToolDoesNotRequireSession(t *testing.T) { ctx := WithRequestSource(context.Background(), RequestSourceIPC) ctx = WithRequestACL(ctx, NewStrictControlPlaneACL()) @@ -1040,6 +1138,12 @@ func (s 
*runtimePortOnlyStub) GetRuntimeSnapshot(_ context.Context, _ GetRuntime func (s *runtimePortOnlyStub) CreateSession(_ context.Context, _ CreateSessionInput) (string, error) { return "", nil } +func (s *runtimePortOnlyStub) SaveSessionAsset(_ context.Context, _ SaveSessionAssetInput) (SessionAssetMeta, error) { + return SessionAssetMeta{}, nil +} +func (s *runtimePortOnlyStub) OpenSessionAsset(_ context.Context, _ OpenSessionAssetInput) (OpenSessionAssetResult, error) { + return OpenSessionAssetResult{}, nil +} func (s *runtimePortOnlyStub) DeleteSession(_ context.Context, _ DeleteSessionInput) (bool, error) { return false, nil } @@ -1118,7 +1222,8 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { "session_id":"session-run-1", "input_parts":[ {"type":"text","text":"hello world"}, - {"type":"image","media":{"uri":"C:/tmp/pic.png","mime_type":"image/png"}} + {"type":"image","media":{"uri":"C:/tmp/pic.png","mime_type":"image/png"}}, + {"type":"image","media":{"asset_id":"asset-1","mime_type":"image/webp"}} ] }`), }, runtimeStub) @@ -1139,8 +1244,8 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { if captured.RunID != "req-run-hydrate" { t.Fatalf("runtime run run_id = %q, want %q", captured.RunID, "req-run-hydrate") } - if len(captured.InputParts) != 2 { - t.Fatalf("runtime run input_parts len = %d, want %d", len(captured.InputParts), 2) + if len(captured.InputParts) != 3 { + t.Fatalf("runtime run input_parts len = %d, want %d", len(captured.InputParts), 3) } if captured.InputParts[0].Type != InputPartTypeText { t.Fatalf("runtime text part type = %q, want %q", captured.InputParts[0].Type, InputPartTypeText) @@ -1151,6 +1256,11 @@ func TestDispatchRPCRequestRunHydratesInputPartsAndFallbackRunID(t *testing.T) { if captured.InputParts[1].Media == nil || captured.InputParts[1].Media.URI != "C:/tmp/pic.png" { t.Fatalf("runtime image media = %#v, want uri %q", captured.InputParts[1].Media, "C:/tmp/pic.png") 
} + if captured.InputParts[2].Media == nil || + captured.InputParts[2].Media.AssetID != "asset-1" || + captured.InputParts[2].Media.MimeType != "image/webp" { + t.Fatalf("runtime image asset media = %#v, want asset_id", captured.InputParts[2].Media) + } } func TestDispatchRPCRequest_DenyCrossSubjectLoadSession(t *testing.T) { @@ -1347,6 +1457,16 @@ func TestDispatchRPCRequestMetricsGrowForTUIMethodSequence(t *testing.T) { t.Fatalf("compact response error: %+v", compact.Error) } + approvePlan := dispatchRPCRequest(ctx, protocol.JSONRPCRequest{ + JSONRPC: protocol.JSONRPCVersion, + ID: json.RawMessage(`"req-approve-tui"`), + Method: protocol.MethodGatewayApprovePlan, + Params: json.RawMessage(`{"session_id":"session-tui","plan_id":"plan-tui","revision":1}`), + }, &rpcRunCaptureRuntimeStub{}) + if approvePlan.Error != nil { + t.Fatalf("approvePlan response error: %+v", approvePlan.Error) + } + listSessions := dispatchRPCRequest(ctx, protocol.JSONRPCRequest{ JSONRPC: protocol.JSONRPCVersion, ID: json.RawMessage(`"req-list-tui"`), @@ -1367,6 +1487,9 @@ func TestDispatchRPCRequestMetricsGrowForTUIMethodSequence(t *testing.T) { if snapshot["ipc|gateway.compact|ok"] == 0 { t.Fatalf("expected compact metric to grow, snapshot=%#v", snapshot) } + if snapshot["ipc|gateway.approveplan|ok"] == 0 { + t.Fatalf("expected approvePlan metric to grow, snapshot=%#v", snapshot) + } if snapshot["ipc|gateway.listsessions|ok"] == 0 { t.Fatalf("expected listSessions metric to grow, snapshot=%#v", snapshot) } diff --git a/internal/gateway/runtime_errors.go b/internal/gateway/runtime_errors.go index df14303ee..d1045cacd 100644 --- a/internal/gateway/runtime_errors.go +++ b/internal/gateway/runtime_errors.go @@ -1,10 +1,55 @@ package gateway -import "errors" +import ( + "errors" + "strings" +) var ( // ErrRuntimeAccessDenied indicates that the runtime denied the current subject access to the target resource. ErrRuntimeAccessDenied = errors.New("runtime access denied") // ErrRuntimeResourceNotFound indicates that the runtime could not find the target resource. ErrRuntimeResourceNotFound =
errors.New("runtime resource not found") + // ErrRuntimeUnavailable indicates that the runtime or one of its optional downstream capabilities is unavailable. + ErrRuntimeUnavailable = errors.New("runtime unavailable") + // ErrRuntimeInvalidAction indicates that the runtime rejected a semantically invalid or stale action. + ErrRuntimeInvalidAction = errors.New("runtime invalid action") + // ErrRuntimeMaxTurnExceeded indicates that the runtime performed a controlled stop after reaching runtime.max_turns. + ErrRuntimeMaxTurnExceeded = errors.New("runtime max turn exceeded") ) + +// RuntimeMaxTurnExceededError carries the runtime's original max_turns stop message so the Gateway can surface it to clients. +type RuntimeMaxTurnExceededError struct { + Detail string +} + +// Error returns the displayable max_turns stop message. +func (e RuntimeMaxTurnExceededError) Error() string { + detail := strings.TrimSpace(e.Detail) + if detail != "" { + return detail + } + return ErrRuntimeMaxTurnExceeded.Error() +} + +// Unwrap preserves the stable sentinel error so that errors.Is can match on semantics. +func (e RuntimeMaxTurnExceededError) Unwrap() error { + return ErrRuntimeMaxTurnExceeded +} + +// NewRuntimeMaxTurnExceededError creates a max_turns controlled-stop error that carries detail. +func NewRuntimeMaxTurnExceededError(detail string) error { + return RuntimeMaxTurnExceededError{Detail: detail} +} + +// RuntimeMaxTurnExceededDetail extracts the display text from a max_turns controlled-stop error. +func RuntimeMaxTurnExceededDetail(err error) string { + var target RuntimeMaxTurnExceededError + if errors.As(err, &target) { + return target.Error() + } + if errors.Is(err, ErrRuntimeMaxTurnExceeded) { + return ErrRuntimeMaxTurnExceeded.Error() + } + return "" +} diff --git a/internal/gateway/security.go b/internal/gateway/security.go index 2cb1406d9..2270013e2 100644 --- a/internal/gateway/security.go +++ b/internal/gateway/security.go @@ -4,7 +4,12 @@ import ( "strings" ) -const pingMethod = "gateway.ping" +const ( + pingMethod = "gateway.ping" + sessionAssetUploadMethod = "gateway.sessionAssetUpload" + sessionAssetReadMethod = "gateway.sessionAssetRead" + sessionAssetDeleteMethod = "gateway.sessionAssetDelete" +) // RequestSource identifies the origin of a control-plane request; it is used for ACL checks and log classification. type RequestSource string @@ -72,6 +77,7 @@ func fullControlPlaneMethods() map[string]struct{} {
"checkpoint.undoRestore", "checkpoint.diff", "gateway.resolvePermission", + "gateway.approvePlan", "gateway.userQuestionAnswer", "gateway.user_question_answer", "gateway.deleteSession", @@ -97,6 +103,9 @@ func fullControlPlaneMethods() map[string]struct{} { "gateway.renameWorkspace", "gateway.deleteWorkspace", "wake.openUrl", + sessionAssetUploadMethod, + sessionAssetReadMethod, + sessionAssetDeleteMethod, } return normalizedMethodSet(methods...) } diff --git a/internal/gateway/security_test.go b/internal/gateway/security_test.go index 183c7c81b..6c8b38253 100644 --- a/internal/gateway/security_test.go +++ b/internal/gateway/security_test.go @@ -39,9 +39,18 @@ func TestStrictACLAllowlist(t *testing.T) { {source: RequestSourceHTTP, method: "checkpoint.restore", want: true}, {source: RequestSourceHTTP, method: "checkpoint.undoRestore", want: true}, {source: RequestSourceHTTP, method: "checkpoint.diff", want: true}, + {source: RequestSourceIPC, method: "gateway.approvePlan", want: true}, + {source: RequestSourceHTTP, method: "gateway.approvePlan", want: true}, + {source: RequestSourceWS, method: "gateway.approvePlan", want: true}, + {source: RequestSourceSSE, method: "gateway.approvePlan", want: false}, {source: RequestSourceHTTP, method: "gateway.userQuestionAnswer", want: true}, {source: RequestSourceHTTP, method: "gateway.user_question_answer", want: true}, + {source: RequestSourceHTTP, method: sessionAssetUploadMethod, want: true}, + {source: RequestSourceHTTP, method: sessionAssetReadMethod, want: true}, + {source: RequestSourceHTTP, method: sessionAssetDeleteMethod, want: true}, + {source: RequestSourceSSE, method: sessionAssetReadMethod, want: false}, {source: RequestSourceUnknown, method: "gateway.ping", want: false}, + {source: RequestSourceUnknown, method: "gateway.approvePlan", want: false}, } for _, tc := range cases { assertACLAllowed(t, acl, tc.source, tc.method, tc.want) diff --git a/internal/gateway/server_test.go b/internal/gateway/server_test.go 
index a61d8eca9..f98053d6b 100644 --- a/internal/gateway/server_test.go +++ b/internal/gateway/server_test.go @@ -367,7 +367,14 @@ func TestServerHandleConnectionAuthenticateFlow(t *testing.T) { } type runtimePortEventStub struct { - events <-chan RuntimeEvent + events <-chan RuntimeEvent + saveAssetFn func(context.Context, SaveSessionAssetInput) (SessionAssetMeta, error) + openAssetFn func(context.Context, OpenSessionAssetInput) (OpenSessionAssetResult, error) + deleteAssetFn func(context.Context, DeleteSessionAssetInput) error +} + +type runtimePortWithoutSessionAsset struct { + RuntimePort } func (s *runtimePortEventStub) Run(_ context.Context, _ RunInput) error { @@ -467,6 +474,27 @@ func (s *runtimePortEventStub) CreateSession(_ context.Context, _ CreateSessionI return "", nil } +func (s *runtimePortEventStub) SaveSessionAsset(ctx context.Context, input SaveSessionAssetInput) (SessionAssetMeta, error) { + if s.saveAssetFn != nil { + return s.saveAssetFn(ctx, input) + } + return SessionAssetMeta{}, nil +} + +func (s *runtimePortEventStub) OpenSessionAsset(ctx context.Context, input OpenSessionAssetInput) (OpenSessionAssetResult, error) { + if s.openAssetFn != nil { + return s.openAssetFn(ctx, input) + } + return OpenSessionAssetResult{}, nil +} + +func (s *runtimePortEventStub) DeleteSessionAsset(ctx context.Context, input DeleteSessionAssetInput) error { + if s.deleteAssetFn != nil { + return s.deleteAssetFn(ctx, input) + } + return nil +} + func (s *runtimePortEventStub) ListSessionTodos(_ context.Context, _ ListSessionTodosInput) (TodoSnapshot, error) { return TodoSnapshot{}, nil } diff --git a/internal/gateway/static_files.go b/internal/gateway/static_files.go index 936f7c78f..12aeb4c17 100644 --- a/internal/gateway/static_files.go +++ b/internal/gateway/static_files.go @@ -16,6 +16,7 @@ var knownAPIPrefixes = map[string]bool{ "/healthz": true, "/version": true, "/rpc": true, + "/api": true, "/ws": true, "/sse": true, "/metrics": true, diff --git 
a/internal/gateway/types.go b/internal/gateway/types.go index 1b3e4a8f6..e2a98a938 100644 --- a/internal/gateway/types.go +++ b/internal/gateway/types.go @@ -58,6 +58,8 @@ const ( FrameActionGetRuntimeSnapshot FrameAction = "runtime_snapshot_get" // FrameActionResolvePermission represents submitting a permission approval decision. FrameActionResolvePermission FrameAction = "resolve_permission" + // FrameActionApprovePlan represents approving the current draft plan revision. + FrameActionApprovePlan FrameAction = "approve_plan" // FrameActionUserQuestionAnswer represents submitting an ask_user answer. FrameActionUserQuestionAnswer FrameAction = "user_question_answer" // FrameActionDeleteSession represents deleting/archiving a session. @@ -134,6 +136,8 @@ const ( type Media struct { // URI is the address of the media resource. URI string `json:"uri"` + // AssetID is the identifier of a saved session asset. + AssetID string `json:"asset_id,omitempty"` // MimeType is the media MIME type. MimeType string `json:"mime_type"` // FileName is the media file name. diff --git a/internal/gateway/validate.go b/internal/gateway/validate.go index ebd979704..4684a323c 100644 --- a/internal/gateway/validate.go +++ b/internal/gateway/validate.go @@ -98,6 +98,8 @@ func validateRequestFrame(frame MessageFrame) *FrameError { return nil case FrameActionResolvePermission: return validateResolvePermissionFrame(frame) + case FrameActionApprovePlan: + return validateApprovePlanFrame(frame) case FrameActionUserQuestionAnswer: return validateUserQuestionAnswerFrame(frame) case FrameActionRestoreCheckpoint, @@ -179,6 +181,42 @@ func decodePermissionResolutionInput(payload any) (PermissionResolutionInput, er return input, nil } +// validateApprovePlanFrame validates the fields required by the approve_plan action. +func validateApprovePlanFrame(frame MessageFrame) *FrameError { + if frame.Payload == nil { + return NewMissingRequiredFieldError("payload") + } + + input, err := decodeApprovePlanPayload(frame.Payload) + if err != nil { + return err + } + if strings.TrimSpace(input.SessionID) == "" { + return NewMissingRequiredFieldError("payload.session_id") + } + if strings.TrimSpace(input.PlanID) == "" { + return
NewMissingRequiredFieldError("payload.plan_id") + } + if input.Revision <= 0 { + return NewFrameError(ErrorCodeInvalidAction, "invalid approve_plan revision") + } + + return nil +} + +// decodeApprovePlanPayload decodes the payload into approve-plan input. +func decodeApprovePlanPayload(payload any) (ApprovePlanInput, *FrameError) { + var params protocol.ApprovePlanParams + if err := decodePayload(payload, &params); err != nil { + return ApprovePlanInput{}, NewFrameError(ErrorCodeInvalidFrame, "invalid approve_plan payload") + } + return ApprovePlanInput{ + SessionID: strings.TrimSpace(params.SessionID), + PlanID: strings.TrimSpace(params.PlanID), + Revision: params.Revision, + }, nil +} + // decodeRenameSessionPayload decodes the payload parameters of renameSession. func decodeRenameSessionPayload(payload any) (renameSessionParams, *FrameError) { switch typed := payload.(type) { @@ -547,8 +585,10 @@ func validateInputPart(part InputPart, index int) *FrameError { if part.Media == nil { return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media") } - if strings.TrimSpace(part.Media.URI) == "" { - return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media.uri") + hasURI := strings.TrimSpace(part.Media.URI) != "" + hasAssetID := strings.TrimSpace(part.Media.AssetID) != "" + if hasURI == hasAssetID { + return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires exactly one of media.uri or media.asset_id") } if strings.TrimSpace(part.Media.MimeType) == "" { return NewFrameError(ErrorCodeInvalidMultimodalPayload, "input_parts[image] requires media.mime_type") @@ -595,6 +635,7 @@ func isValidFrameAction(action FrameAction) bool { FrameActionListSessionTodos, FrameActionGetRuntimeSnapshot, FrameActionResolvePermission, + FrameActionApprovePlan, FrameActionDeleteSession, FrameActionRenameSession, FrameActionListFiles, diff --git a/internal/gateway/validate_test.go b/internal/gateway/validate_test.go index 4b6a0f38d..958e46b07 100644 ---
a/internal/gateway/validate_test.go +++ b/internal/gateway/validate_test.go @@ -184,6 +184,76 @@ func TestValidateFrame_BasicRules(t *testing.T) { }, wantCode: ErrorCodeInvalidAction.String(), }, + { + name: "approve_plan valid payload", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "plan-1", + "revision": 1, + }, + }, + wantNil: true, + }, + { + name: "approve_plan missing payload", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + }, + wantCode: ErrorCodeMissingRequiredField.String(), + wantField: "payload", + }, + { + name: "approve_plan missing session_id", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "plan_id": "plan-1", + "revision": 1, + }, + }, + wantCode: ErrorCodeMissingRequiredField.String(), + wantField: "payload.session_id", + }, + { + name: "approve_plan missing plan_id", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "session_id": "session-1", + "revision": 1, + }, + }, + wantCode: ErrorCodeMissingRequiredField.String(), + wantField: "payload.plan_id", + }, + { + name: "approve_plan invalid revision", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: map[string]any{ + "session_id": "session-1", + "plan_id": "plan-1", + "revision": 0, + }, + }, + wantCode: ErrorCodeInvalidAction.String(), + }, + { + name: "approve_plan invalid payload shape", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionApprovePlan, + Payload: "bad-payload", + }, + wantCode: ErrorCodeInvalidFrame.String(), + }, { name: "event frame allows empty action", frame: MessageFrame{ @@ -759,6 +829,23 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { }, wantNil: true, }, + { + name: "valid image asset part", + frame: MessageFrame{ + Type: 
FrameTypeRequest, + Action: FrameActionRun, + InputParts: []InputPart{ + { + Type: InputPartTypeImage, + Media: &Media{ + AssetID: "asset-1", + MimeType: "image/png", + }, + }, + }, + }, + wantNil: true, + }, { name: "text part with empty text", frame: MessageFrame{ @@ -782,7 +869,7 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { wantCode: ErrorCodeInvalidMultimodalPayload.String(), }, { - name: "image part missing media.uri", + name: "image part missing media.uri and media.asset_id", frame: MessageFrame{ Type: FrameTypeRequest, Action: FrameActionRun, @@ -795,6 +882,24 @@ func TestValidateFrame_MultimodalPayloadRules(t *testing.T) { }, wantCode: ErrorCodeInvalidMultimodalPayload.String(), }, + { + name: "image part has both media.uri and media.asset_id", + frame: MessageFrame{ + Type: FrameTypeRequest, + Action: FrameActionRun, + InputParts: []InputPart{ + { + Type: InputPartTypeImage, + Media: &Media{ + URI: "file:///a.png", + AssetID: "asset-1", + MimeType: "image/png", + }, + }, + }, + }, + wantCode: ErrorCodeInvalidMultimodalPayload.String(), + }, { name: "image part missing media.mime_type", frame: MessageFrame{ diff --git a/internal/promptasset/assets_test.go b/internal/promptasset/assets_test.go index 4523e7faf..f625980f9 100644 --- a/internal/promptasset/assets_test.go +++ b/internal/promptasset/assets_test.go @@ -94,12 +94,21 @@ func TestPlanModePromptTemplates(t *testing.T) { strings.Contains(PlanModePrompt("plan"), "must not be empty") { t.Fatalf("expected plan prompt not to require execution todo ownership") } + if strings.Contains(PlanModePrompt("plan"), "Only output a JSON object") { + t.Fatalf("expected plan prompt not to require JSON-only output") + } + if !strings.Contains(PlanModePrompt("plan"), "inside an HTML comment") { + t.Fatalf("expected plan prompt to require machine-readable JSON in an HTML comment") + } if !strings.Contains(PlanModePrompt("plan"), "Do not create execution todos in plan mode") { t.Fatalf("expected plan 
prompt to keep todos in build execution") } if !strings.Contains(PlanModePrompt("build_execute"), "create current-run required todos") { t.Fatalf("expected build prompt to require direct-build todo bootstrap") } + if !strings.Contains(PlanModePrompt("build_execute"), "Todo State is attached as `None`") { + t.Fatalf("expected build prompt to bootstrap when Todo State is None") + } if !strings.Contains(PlanModePrompt("build_execute"), "simple conversational inputs") { t.Fatalf("expected build prompt to cover simple conversational completion") } diff --git a/internal/promptasset/templates/context/plan_mode_build_execute.md b/internal/promptasset/templates/context/plan_mode_build_execute.md index 2996a2c9a..d1acdc035 100644 --- a/internal/promptasset/templates/context/plan_mode_build_execute.md +++ b/internal/promptasset/templates/context/plan_mode_build_execute.md @@ -4,7 +4,7 @@ You are currently in build execution. - If a current plan summary is attached, use it as guidance by default. - If the summary is insufficient for the current task, consult the attached full plan view when available. - If no current plan is attached, continue using task state, todos, and the conversation context. -- If no Todo State is attached, create current-run required todos with `todo_write` before the first substantive tool call for project analysis, documentation writing, code changes, multi-step debugging, or verification work. +- If no Todo State is attached, or Todo State is attached as `None`, create current-run required todos with `todo_write` before the first substantive tool call for project analysis, documentation writing, code changes, multi-step debugging, or verification work. - Do not update or complete todo IDs that are not present in the current Todo State; create new current-run todos instead. - Small necessary deviations are allowed, but explain why they are needed. - Do not create or rewrite the current full plan in this stage. 
diff --git a/internal/promptasset/templates/context/plan_mode_plan.md b/internal/promptasset/templates/context/plan_mode_plan.md index 4cfd71427..b5c422e62 100644 --- a/internal/promptasset/templates/context/plan_mode_plan.md +++ b/internal/promptasset/templates/context/plan_mode_plan.md @@ -3,9 +3,9 @@ You are currently in the planning stage. - You may research, analyze, ask clarifying questions, and produce a plan. - Do not perform any write action in this stage. - Do not rewrite the current full plan unless the conversation clearly requires creating or replacing the plan itself. -- **If no Current Plan section is attached, your first priority is to produce a plan.** The user has entered planning mode expecting a structured plan. Research the codebase as needed, then output a complete `plan_spec` + `summary_candidate` JSON. Do not end the turn with only a conversational answer when there is no existing plan. +- **If no Current Plan section is attached, your first priority is to produce a plan.** The user has entered planning mode expecting a structured plan. Research the codebase as needed, then output a visible Markdown plan followed by one compact machine-readable JSON object containing `plan_spec` and `summary_candidate` inside an HTML comment. Do not end the turn with only a conversational answer when there is no existing plan. - If a Current Plan is already present, you may refine, replace, or discuss it. When the user asks a clarifying question or wants to explore options without committing to a new plan revision, you may answer conversationally without outputting planning JSON. -- Only output a JSON object containing `plan_spec` and `summary_candidate` when you are explicitly creating or rewriting the current full plan. +- When explicitly creating or rewriting the current full plan, output the visible plan as Markdown first, then append the machine-readable JSON inside an HTML comment, not in a fenced code block. 
- `plan_spec` must include `goal`, `steps`, `constraints`, and `open_questions`. - `plan_spec.todos` is optional legacy data. Do not create execution todos in plan mode; build mode will create and maintain runtime todos when implementation starts. - `summary_candidate` must include `goal`, `key_steps`, and `constraints`. diff --git a/internal/promptasset/templates/core/tool_usage.md b/internal/promptasset/templates/core/tool_usage.md index d00135997..b27a3819d 100644 --- a/internal/promptasset/templates/core/tool_usage.md +++ b/internal/promptasset/templates/core/tool_usage.md @@ -32,11 +32,7 @@ For general file operations outside of codebase exploration, use `filesystem_*` - For simple create/overwrite tasks, prefer `filesystem_write_file` with `verify_after_write=true` so one call can emit write + verification facts. - Do not use `bash` to edit files when the filesystem tools can make the change safely. - For file system structure changes inside the workspace, prefer the dedicated tools over `bash`: - - rename/move: `filesystem_move_file` (not `bash mv`) - - copy: `filesystem_copy_file` (not `bash cp`) - delete file: `filesystem_delete_file` (not `bash rm`) - - create directory: `filesystem_create_dir` (not `bash mkdir`) - - remove directory: `filesystem_remove_dir` (not `bash rmdir` / `rm -rf`) These tools record their changes for checkpoint/rollback; equivalent `bash` commands produce reduced rollback coverage. - For multi-step implementation, debugging, refactoring, or long-running work, keep task state explicit via `todo_write` (plan/add/update/set_status/claim/complete/fail) when that tool is available and the current mode permits execution todo updates. - Create todos that map to real acceptance work, not vague activity. 
diff --git a/internal/provider/anthropic/provider.go b/internal/provider/anthropic/provider.go index df749b868..8d8092f8e 100644 --- a/internal/provider/anthropic/provider.go +++ b/internal/provider/anthropic/provider.go @@ -41,6 +41,17 @@ func (p *Provider) EstimateInputTokens( ctx context.Context, req providertypes.GenerateRequest, ) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } params, err := BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/deepseek/provider.go b/internal/provider/deepseek/provider.go index 55956c7ef..1ffbab9dc 100644 --- a/internal/provider/deepseek/provider.go +++ b/internal/provider/deepseek/provider.go @@ -40,6 +40,17 @@ func (p *Provider) EstimateInputTokens( ctx context.Context, req providertypes.GenerateRequest, ) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/deepseek/provider_more_test.go b/internal/provider/deepseek/provider_more_test.go index d517cb340..7f259ef44 100644 --- a/internal/provider/deepseek/provider_more_test.go +++ 
b/internal/provider/deepseek/provider_more_test.go @@ -89,6 +89,22 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { t.Fatalf("expected positive token estimate, got %+v", estimate) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } + events := make(chan providertypes.StreamEvent, 8) if err := p.Generate(context.Background(), req, events); err != nil { t.Fatalf("Generate() error = %v", err) diff --git a/internal/provider/estimate.go b/internal/provider/estimate.go index 2f0499e8e..e94b59063 100644 --- a/internal/provider/estimate.go +++ b/internal/provider/estimate.go @@ -4,7 +4,9 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "math" + "strings" providertypes "neo-code/internal/provider/types" ) @@ -15,6 +17,8 @@ const ( EstimateGateAdvisory = "advisory" EstimateGateGateable = "gateable" localEstimateSlack = 1.15 + // DefaultImageInputTokenEstimate is the conservative per-image budget estimate used when the image dimensions cannot be read. + DefaultImageInputTokenEstimate = 2048 ) // EstimateSerializedPayloadTokens estimates the input token count from the serialized form of the final protocol payload. @@ -34,6 +38,109 @@ func EstimateTextTokens(text string) int { return int(math.Ceil(float64(len([]byte(text))) / 4.0 * localEstimateSlack)) } +// RequestContainsImagePart reports whether the request contains an image part, so a provider can choose the multimodal projected-estimate path. +func RequestContainsImagePart(req providertypes.GenerateRequest)
bool { + for _, message := range req.Messages { + for _, part := range message.Parts { + if part.Kind == providertypes.ContentPartImage { + return true + } + } + } + return false +} + +// ResolveRequestModel resolves the effective model name, preferring the request's model and falling back to the configured default. +func ResolveRequestModel(req providertypes.GenerateRequest, defaultModel string) string { + model := strings.TrimSpace(req.Model) + if model == "" { + model = strings.TrimSpace(defaultModel) + } + return model +} + +// EstimateProjectedInputTokens estimates only the semantic input; it does not count an image's base64 transport body toward prompt tokens. +func EstimateProjectedInputTokens(req providertypes.GenerateRequest, model string) (int, error) { + if strings.TrimSpace(model) == "" { + return 0, errors.New("model is empty") + } + + var textBuilder strings.Builder + textBuilder.WriteString(model) + textBuilder.WriteByte('\n') + textBuilder.WriteString(req.SystemPrompt) + textBuilder.WriteByte('\n') + + imageCount := 0 + for _, message := range req.Messages { + if err := providertypes.ValidateParts(message.Parts); err != nil { + return 0, err + } + textBuilder.WriteString(message.Role) + textBuilder.WriteByte('\n') + textBuilder.WriteString(message.ToolCallID) + textBuilder.WriteByte('\n') + for _, part := range message.Parts { + switch part.Kind { + case providertypes.ContentPartText: + textBuilder.WriteString(part.Text) + textBuilder.WriteByte('\n') + case providertypes.ContentPartImage: + imageCount++ + if part.Image != nil { + textBuilder.WriteString(string(part.Image.SourceType)) + textBuilder.WriteByte('\n') + textBuilder.WriteString(part.Image.URL) + textBuilder.WriteByte('\n') + if part.Image.Asset != nil { + textBuilder.WriteString(part.Image.Asset.ID) + textBuilder.WriteByte('\n') + textBuilder.WriteString(part.Image.Asset.MimeType) + textBuilder.WriteByte('\n') + } + } + } + } + for _, call := range message.ToolCalls { + textBuilder.WriteString(call.ID) + textBuilder.WriteByte('\n') + textBuilder.WriteString(call.Name) + textBuilder.WriteByte('\n') +
textBuilder.WriteString(call.Arguments) + textBuilder.WriteByte('\n') + } + for key, value := range message.ToolMetadata { + textBuilder.WriteString(key) + textBuilder.WriteByte('=') + textBuilder.WriteString(value) + textBuilder.WriteByte('\n') + } + } + + for _, spec := range req.Tools { + textBuilder.WriteString(spec.Name) + textBuilder.WriteByte('\n') + textBuilder.WriteString(spec.Description) + textBuilder.WriteByte('\n') + normalized := NormalizeToolSchemaObject(spec.Schema) + encoded, err := json.Marshal(normalized) + if err != nil { + return 0, err + } + textBuilder.Write(encoded) + textBuilder.WriteByte('\n') + } + if req.ThinkingConfig != nil { + textBuilder.WriteString(req.ThinkingConfig.Effort) + textBuilder.WriteByte('\n') + if req.ThinkingConfig.Enabled { + textBuilder.WriteString("thinking_enabled") + } + } + + return EstimateTextTokens(textBuilder.String()) + imageCount*DefaultImageInputTokenEstimate, nil +} + // BuildGenerateRequestSignature builds a stable signature for a GenerateRequest, used to match requests for reuse between the estimation and send phases. func BuildGenerateRequestSignature(req providertypes.GenerateRequest) string { encoded, err := json.Marshal(req) diff --git a/internal/provider/estimate_test.go b/internal/provider/estimate_test.go index 5fecb1351..7e830b66d 100644 --- a/internal/provider/estimate_test.go +++ b/internal/provider/estimate_test.go @@ -1,6 +1,7 @@ package provider import ( + "strings" "testing" providertypes "neo-code/internal/provider/types" @@ -40,6 +41,130 @@ func TestEstimateTextTokens(t *testing.T) { } } +func TestResolveRequestModel(t *testing.T) { + t.Parallel() + + req := providertypes.GenerateRequest{Model: " request-model "} + if got := ResolveRequestModel(req, "default-model"); got != "request-model" { + t.Fatalf("ResolveRequestModel() = %q, want request model", got) + } + + req.Model = " " + if got := ResolveRequestModel(req, " default-model "); got != "default-model" { + t.Fatalf("ResolveRequestModel() fallback = %q, want default model", got) + } +} + +func
TestRequestContainsImagePart(t *testing.T) { + t.Parallel() + + textOnly := providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}, + }}, + } + if RequestContainsImagePart(textOnly) { + t.Fatal("expected text-only request to report no images") + } + withImage := textOnly + withImage.Messages[0].Parts = append(withImage.Messages[0].Parts, providertypes.NewSessionAssetImagePart("asset-1", "image/png")) + if !RequestContainsImagePart(withImage) { + t.Fatal("expected image request to report images") + } +} + +func TestEstimateProjectedInputTokensDoesNotCountBase64Transport(t *testing.T) { + t.Parallel() + + tokens, err := EstimateProjectedInputTokens(providertypes.GenerateRequest{ + SystemPrompt: "You are concise.", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("describe this"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + }}, + Tools: []providertypes.ToolSpec{{ + Name: "filesystem_read_file", + Description: "Read a file", + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "path": map[string]any{"type": "string"}, + }, + }, + }}, + }, "gpt-4.1") + if err != nil { + t.Fatalf("EstimateProjectedInputTokens() error = %v", err) + } + if tokens <= DefaultImageInputTokenEstimate { + t.Fatalf("expected text and tool schema to add tokens, got %d", tokens) + } + if tokens > 10_000 { + t.Fatalf("projected estimate counted transport-sized payload, got %d", tokens) + } + + oneMiBDataURLTokens := EstimateTextTokens(strings.Repeat("x", int(EstimateDataURLTransportBytes(1024*1024, "image/png")))) + if tokens >= oneMiBDataURLTokens { + t.Fatalf("projected estimate = %d, want below data URL transport estimate %d", tokens, oneMiBDataURLTokens) + } +} + +func TestEstimateProjectedInputTokensValidatesPartsAndModel(t *testing.T) 
{ + t.Parallel() + + if _, err := EstimateProjectedInputTokens(providertypes.GenerateRequest{}, " "); err == nil { + t.Fatal("expected empty model error") + } + _, err := EstimateProjectedInputTokens(providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{{Kind: "invalid"}}, + }}, + }, "gpt") + if err == nil { + t.Fatal("expected invalid parts error") + } + + _, err = EstimateProjectedInputTokens(providertypes.GenerateRequest{ + Model: "gpt", + Tools: []providertypes.ToolSpec{{Name: "bad", Schema: map[string]any{"unsupported": func() {}}}}, + }, "gpt") + if err == nil { + t.Fatal("expected invalid tool schema error") + } +} + +func TestEstimateProjectedInputTokensCoversMetadataAndImageSources(t *testing.T) { + t.Parallel() + + tokens, err := EstimateProjectedInputTokens(providertypes.GenerateRequest{ + SystemPrompt: "system", + Messages: []providertypes.Message{{ + Role: providertypes.RoleTool, + ToolCallID: "tool-call-1", + Parts: []providertypes.ContentPart{ + providertypes.NewRemoteImagePart("https://example.com/a.png"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + ToolCalls: []providertypes.ToolCall{{ID: "call-1", Name: "bash", Arguments: `{"cmd":"pwd"}`}}, + ToolMetadata: map[string]string{ + "exit_code": "0", + }, + }}, + ThinkingConfig: &providertypes.ThinkingConfig{Enabled: true, Effort: "medium"}, + }, "gpt-4.1") + if err != nil { + t.Fatalf("EstimateProjectedInputTokens() error = %v", err) + } + if tokens <= 2*DefaultImageInputTokenEstimate { + t.Fatalf("expected metadata text to add tokens, got %d", tokens) + } +} + func TestBuildGenerateRequestSignature(t *testing.T) { t.Parallel() @@ -68,4 +193,10 @@ func TestBuildGenerateRequestSignature(t *testing.T) { if sigA == sigC { t.Fatalf("different requests should have different signatures: %q == %q", sigA, sigC) } + + bad := reqA + bad.Tools = []providertypes.ToolSpec{{Name: "bad", Schema: 
map[string]any{"unsupported": func() {}}}} + if got := BuildGenerateRequestSignature(bad); got != "" { + t.Fatalf("BuildGenerateRequestSignature(bad) = %q, want empty signature", got) + } } diff --git a/internal/provider/gemini/provider.go b/internal/provider/gemini/provider.go index 2739e2cf6..e63785d7e 100644 --- a/internal/provider/gemini/provider.go +++ b/internal/provider/gemini/provider.go @@ -44,6 +44,17 @@ func (p *Provider) EstimateInputTokens( ctx context.Context, req providertypes.GenerateRequest, ) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } model, contents, genConfig, err := BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/mimo/provider.go b/internal/provider/mimo/provider.go index a5eac3580..688f2ed47 100644 --- a/internal/provider/mimo/provider.go +++ b/internal/provider/mimo/provider.go @@ -37,6 +37,17 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { } func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) if err != nil { return 
providertypes.BudgetEstimate{}, err diff --git a/internal/provider/mimo/provider_more_test.go b/internal/provider/mimo/provider_more_test.go index 5caa66f95..5fb02c8ff 100644 --- a/internal/provider/mimo/provider_more_test.go +++ b/internal/provider/mimo/provider_more_test.go @@ -80,6 +80,21 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { t.Fatalf("EstimateInputTokens() error = %v", err) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } events := make(chan providertypes.StreamEvent, 8) if err := p.Generate(context.Background(), req, events); err != nil { t.Fatalf("Generate() error = %v", err) diff --git a/internal/provider/minimax/provider.go b/internal/provider/minimax/provider.go index 4f5b9d5d0..86cc6b1f8 100644 --- a/internal/provider/minimax/provider.go +++ b/internal/provider/minimax/provider.go @@ -40,6 +40,17 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { } func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + 
return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/minimax/provider_more_test.go b/internal/provider/minimax/provider_more_test.go index 92e0ddc22..b375fb372 100644 --- a/internal/provider/minimax/provider_more_test.go +++ b/internal/provider/minimax/provider_more_test.go @@ -79,6 +79,21 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { t.Fatalf("EstimateInputTokens() error = %v", err) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } p, err = New(provider.RuntimeConfig{ BaseURL: server.URL + "/chat/completions", APIKeyEnv: "TEST_KEY", diff --git a/internal/provider/openaicompat/glm/provider.go b/internal/provider/openaicompat/glm/provider.go index e4daae228..d5a245fc3 100644 --- a/internal/provider/openaicompat/glm/provider.go +++ b/internal/provider/openaicompat/glm/provider.go @@ -37,6 +37,17 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { } func (p 
*Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/openaicompat/glm/provider_more_test.go b/internal/provider/openaicompat/glm/provider_more_test.go index ad8cc4e67..eebe44d62 100644 --- a/internal/provider/openaicompat/glm/provider_more_test.go +++ b/internal/provider/openaicompat/glm/provider_more_test.go @@ -80,6 +80,21 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { t.Fatalf("EstimateInputTokens() error = %v", err) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } events := make(chan providertypes.StreamEvent, 8) if err := p.Generate(context.Background(), req, 
events); err != nil { t.Fatalf("Generate() error = %v", err) diff --git a/internal/provider/openaicompat/openaicompat_test.go b/internal/provider/openaicompat/openaicompat_test.go index db290abf6..a7c4690a4 100644 --- a/internal/provider/openaicompat/openaicompat_test.go +++ b/internal/provider/openaicompat/openaicompat_test.go @@ -276,6 +276,46 @@ func TestEstimateInputTokensReturnsAdvisoryLocalEstimate(t *testing.T) { } } +func TestEstimateInputTokensWithImageUsesProjectedEstimate(t *testing.T) { + t.Parallel() + + p, err := New(resolvedConfig("", "gpt-4.1")) + if err != nil { + t.Fatalf("New() error = %v", err) + } + reader := &singleUseSessionAssetReader{ + assets: map[string]sessionAsset{ + "asset-1": {data: []byte(strings.Repeat("x", 1024*1024)), mime: "image/png"}, + }, + } + estimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Model: "gpt-4.1", + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("describe"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + }}, + SessionAssetReader: reader, + }) + if err != nil { + t.Fatalf("EstimateInputTokens() error = %v", err) + } + if reader.openCount != 0 { + t.Fatalf("expected estimate not to open session asset, got %d opens", reader.openCount) + } + if estimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected text+image estimate, got %+v", estimate) + } + if estimate.EstimatedInputTokens > 10_000 { + t.Fatalf("estimate counted base64 transport payload, got %+v", estimate) + } + if estimate.GatePolicy != provider.EstimateGateAdvisory { + t.Fatalf("gate policy = %q, want %q", estimate.GatePolicy, provider.EstimateGateAdvisory) + } +} + func TestEstimateThenGenerateReusesPreparedRequest(t *testing.T) { t.Setenv(config.OpenAIDefaultAPIKeyEnv, "test-key") diff --git a/internal/provider/openaicompat/provider.go 
b/internal/provider/openaicompat/provider.go index e43f2e18c..15d79a552 100644 --- a/internal/provider/openaicompat/provider.go +++ b/internal/provider/openaicompat/provider.go @@ -61,6 +61,17 @@ func (p *Provider) EstimateInputTokens( } var tokens int + if provider.RequestContainsImagePart(req) { + tokens, err = provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } switch mode { case executionModeCompletions: payload, buildErr := chatcompletions.BuildRequest(ctx, p.cfg, req) diff --git a/internal/provider/openaicompat/qwen/provider.go b/internal/provider/openaicompat/qwen/provider.go index 896c4efa0..500ca25ae 100644 --- a/internal/provider/openaicompat/qwen/provider.go +++ b/internal/provider/openaicompat/qwen/provider.go @@ -37,6 +37,17 @@ func New(cfg provider.RuntimeConfig) (*Provider, error) { } func (p *Provider) EstimateInputTokens(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + if provider.RequestContainsImagePart(req) { + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, p.cfg.DefaultModel)) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + } payload, err := chatcompletions.BuildRequest(ctx, p.cfg, req) if err != nil { return providertypes.BudgetEstimate{}, err diff --git a/internal/provider/openaicompat/qwen/provider_more_test.go b/internal/provider/openaicompat/qwen/provider_more_test.go index 0cf27f966..dce1580c1 100644 --- a/internal/provider/openaicompat/qwen/provider_more_test.go +++ 
b/internal/provider/openaicompat/qwen/provider_more_test.go @@ -80,6 +80,21 @@ func TestProviderEstimateGenerateAndThinkingErrors(t *testing.T) { if _, err := p.EstimateInputTokens(context.Background(), req); err != nil { t.Fatalf("EstimateInputTokens() error = %v", err) } + imageEstimate, err := p.EstimateInputTokens(context.Background(), providertypes.GenerateRequest{ + Messages: []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }}, + }) + if err != nil { + t.Fatalf("EstimateInputTokens(image) error = %v", err) + } + if imageEstimate.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate { + t.Fatalf("expected projected image estimate with model text, got %+v", imageEstimate) + } + if imageEstimate.GatePolicy != provider.EstimateGateAdvisory || imageEstimate.EstimateSource != provider.EstimateSourceLocal { + t.Fatalf("unexpected image estimate metadata: %+v", imageEstimate) + } events := make(chan providertypes.StreamEvent, 8) if err := p.Generate(context.Background(), req, events); err != nil { t.Fatalf("Generate() error = %v", err) diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 2fa1632b2..14c5c3dd0 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -104,11 +104,7 @@ func New(cfg Config) (*Runner, error) { toolMgr.Register(filesystem.NewGrep(workdir)) toolMgr.Register(filesystem.NewGlob(workdir)) toolMgr.Register(filesystem.NewEdit(workdir)) - toolMgr.Register(filesystem.NewMove(workdir)) - toolMgr.Register(filesystem.NewCopy(workdir)) toolMgr.Register(filesystem.NewDelete(workdir)) - toolMgr.Register(filesystem.NewCreateDir(workdir)) - toolMgr.Register(filesystem.NewRemoveDir(workdir)) toolMgr.Register(bash.New(workdir, shell, cfg.RequestTimeout)) toolMgr.Register(webfetch.New(webfetch.Config{Timeout: cfg.RequestTimeout})) toolMgr.Register(diagnosetool.New()) diff --git 
a/internal/runner/runner_test.go b/internal/runner/runner_test.go index d96bfe748..8d14201cb 100644 --- a/internal/runner/runner_test.go +++ b/internal/runner/runner_test.go @@ -30,14 +30,6 @@ func (m *runnerManagerAdapter) ListAvailableSpecs(context.Context, tools.SpecLis return nil, nil } -func (m *runnerManagerAdapter) MicroCompactPolicy(string) tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - -func (m *runnerManagerAdapter) MicroCompactSummarizer(string) tools.ContentSummarizer { - return nil -} - func (m *runnerManagerAdapter) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { if m.executeFn != nil { return m.executeFn(ctx, input) diff --git a/internal/runtime/ask_session.go b/internal/runtime/ask_session.go index 418c8304b..1ada1f72d 100644 --- a/internal/runtime/ask_session.go +++ b/internal/runtime/ask_session.go @@ -45,4 +45,3 @@ func normalizeAskMessageRole(role string) string { return "assistant" } } - diff --git a/internal/runtime/ask_store.go b/internal/runtime/ask_store.go index 0a070d8fb..a349ad332 100644 --- a/internal/runtime/ask_store.go +++ b/internal/runtime/ask_store.go @@ -117,4 +117,3 @@ func (s *inMemoryAskSessionStore) cleanupExpiredLocked(now time.Time) { delete(s.sessions, sessionID) } } - diff --git a/internal/runtime/checkpoint_flow_test.go b/internal/runtime/checkpoint_flow_test.go index 6027fcdae..fddc34f1f 100644 --- a/internal/runtime/checkpoint_flow_test.go +++ b/internal/runtime/checkpoint_flow_test.go @@ -2,6 +2,7 @@ package runtime import ( "context" + "errors" "os" "path/filepath" "strings" @@ -17,6 +18,7 @@ import ( type checkpointStoreSpy struct { lastResume agentsession.ResumeCheckpoint + setResumeErr error latestResume *agentsession.ResumeCheckpoint latestResumeErr error listRecords []agentsession.CheckpointRecord @@ -56,7 +58,7 @@ func (s *checkpointStoreSpy) RestoreCheckpoint(context.Context, checkpoint.Resto func (s *checkpointStoreSpy) SetResumeCheckpoint(_ 
context.Context, rc agentsession.ResumeCheckpoint) error { s.lastResume = rc - return nil + return s.setResumeErr } func (s *checkpointStoreSpy) PruneExpiredCheckpoints(context.Context, string, int) (int, error) { @@ -342,18 +344,52 @@ func TestUpdateResumeCheckpoint(t *testing.T) { } } +func TestUpdateResumeCheckpointSkipsWhenStoreUnavailable(t *testing.T) { + t.Parallel() + + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-no-store", fixture.session) + service := &Service{} + service.updateResumeCheckpoint(context.Background(), &state, "plan", "") +} + +func TestUpdateResumeCheckpointSwallowsStoreErrorAndUsesEffectiveWorkdir(t *testing.T) { + t.Parallel() + + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-resume-workdir-fallback", fixture.session) + state.session.Workdir = " " + state.effectiveWorkdir = fixture.workdir + spy := &checkpointStoreSpy{setResumeErr: errors.New("write failed")} + service := &Service{checkpointStore: spy} + + service.updateResumeCheckpoint(context.Background(), &state, "verify", "running") + + if spy.lastResume.WorkspaceKey != agentsession.WorkspacePathKey(fixture.workdir) { + t.Fatalf("WorkspaceKey = %q, want %q", spy.lastResume.WorkspaceKey, agentsession.WorkspacePathKey(fixture.workdir)) + } +} + func TestApplyResumeCheckpointReplayPlanStrategy(t *testing.T) { fixture := newRuntimeCheckpointFixture(t) state := newRunState("run-resume-plan", fixture.session) + currentWorkspaceKey := resolveResumeWorkspaceKey(state.session.Workdir, state.effectiveWorkdir) + currentTranscriptRevision := sessionTranscriptRevision(state.session) + currentLegacyTranscriptRevision := sessionLegacyTranscriptRevision(state.session) spy := &checkpointStoreSpy{ latestResume: &agentsession.ResumeCheckpoint{ - RunID: "run-old", - SessionID: fixture.session.ID, - Turn: 4, - Phase: "execute", - CompletionState: "", + RunID: "run-old", + WorkspaceKey: currentWorkspaceKey, + SessionID: fixture.session.ID, + Turn: 4, + 
Phase: "execute", + CompletionState: "", + TranscriptRevision: currentTranscriptRevision, }, } + if !resumeCheckpointMatchesState(*spy.latestResume, currentWorkspaceKey, currentTranscriptRevision, currentLegacyTranscriptRevision) { + t.Fatalf("resume checkpoint should match current state in replay-plan strategy case") + } service := &Service{ checkpointStore: spy, events: make(chan RuntimeEvent, 16), @@ -375,15 +411,23 @@ func TestApplyResumeCheckpointReplayPlanStrategy(t *testing.T) { func TestApplyResumeCheckpointVerifyClosureStrategy(t *testing.T) { fixture := newRuntimeCheckpointFixture(t) state := newRunState("run-resume-verify", fixture.session) + currentWorkspaceKey := resolveResumeWorkspaceKey(state.session.Workdir, state.effectiveWorkdir) + currentTranscriptRevision := sessionTranscriptRevision(state.session) + currentLegacyTranscriptRevision := sessionLegacyTranscriptRevision(state.session) spy := &checkpointStoreSpy{ latestResume: &agentsession.ResumeCheckpoint{ - RunID: "run-old-verify", - SessionID: fixture.session.ID, - Turn: 2, - Phase: "verify", - CompletionState: "completed", + RunID: "run-old-verify", + WorkspaceKey: currentWorkspaceKey, + SessionID: fixture.session.ID, + Turn: 2, + Phase: "verify", + CompletionState: "completed", + TranscriptRevision: currentTranscriptRevision, }, } + if !resumeCheckpointMatchesState(*spy.latestResume, currentWorkspaceKey, currentTranscriptRevision, currentLegacyTranscriptRevision) { + t.Fatalf("resume checkpoint should match current state in verify-closure strategy case") + } service := &Service{ checkpointStore: spy, events: make(chan RuntimeEvent, 16), @@ -418,11 +462,13 @@ func TestApplyResumeCheckpointSkipsUnsupportedInputs(t *testing.T) { unsupportedState := newRunState("run-unsupported-resume", fixture.session) service.checkpointStore = &checkpointStoreSpy{ latestResume: &agentsession.ResumeCheckpoint{ - RunID: "run-unknown", - SessionID: fixture.session.ID, - Turn: 1, - Phase: "stopped", - 
CompletionState: "completed", + RunID: "run-unknown", + WorkspaceKey: agentsession.WorkspacePathKey(fixture.session.Workdir), + SessionID: fixture.session.ID, + Turn: 1, + Phase: "stopped", + CompletionState: "completed", + TranscriptRevision: sessionTranscriptRevision(fixture.session), }, } service.applyResumeCheckpoint(context.Background(), &unsupportedState) @@ -466,11 +512,301 @@ func TestDeriveResumeBaseLifecycle(t *testing.T) { } } +func TestSessionTranscriptRevisionChangesWithoutMessageCountChange(t *testing.T) { + t.Parallel() + + base := agentsession.NewWithWorkdir("resume-token", t.TempDir()) + base.Messages = []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("same-message-count"), + }, + }, + } + base.UpdatedAt = time.Unix(1700000000, 0).UTC() + base.TaskState = agentsession.TaskState{ + Goal: "implement flow", + LastUpdatedAt: time.Unix(1700000001, 0).UTC(), + } + base.TodoVersion = 1 + + next := base + next.TaskState.NextStep = "run verification" + next.TaskState.LastUpdatedAt = time.Unix(1700000002, 0).UTC() + next.UpdatedAt = time.Unix(1700000003, 0).UTC() + next.TodoVersion = 2 + + if len(base.Messages) != len(next.Messages) { + t.Fatalf("message count changed unexpectedly: %d -> %d", len(base.Messages), len(next.Messages)) + } + if sessionTranscriptRevision(base) == sessionTranscriptRevision(next) { + t.Fatalf("expected transcript revision token to change when task/todo state changes without message-count change") + } +} + +func TestSessionTranscriptRevisionIgnoresMetadataOnlyUpdatedAt(t *testing.T) { + t.Parallel() + + base := agentsession.NewWithWorkdir("resume-token-metadata", t.TempDir()) + base.Messages = []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("same"), + }, + }, + } + base.TaskState = agentsession.TaskState{ + Goal: "same-goal", + LastUpdatedAt: time.Unix(1700001000, 0).UTC(), + } 
+ base.UpdatedAt = time.Unix(1700002000, 0).UTC() + + next := base + next.UpdatedAt = time.Unix(1700009000, 0).UTC() + + if sessionTranscriptRevision(base) != sessionTranscriptRevision(next) { + t.Fatalf("expected transcript revision token to ignore metadata-only UpdatedAt changes") + } +} + +func TestResumeCheckpointMatchesStateSupportsLegacyMessageCountFallback(t *testing.T) { + t.Parallel() + + session := agentsession.NewWithWorkdir("legacy-resume", t.TempDir()) + session.Messages = []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("a"), + }, + }, + } + session.TaskState = agentsession.TaskState{ + Goal: "legacy fallback", + LastUpdatedAt: time.Unix(1700010000, 0).UTC(), + } + workspaceKey := agentsession.WorkspacePathKey(session.Workdir) + currentRevision := sessionTranscriptRevision(session) + legacyRevision := sessionLegacyTranscriptRevision(session) + if legacyRevision <= 0 { + t.Fatalf("legacy transcript revision should be positive, got %d", legacyRevision) + } + + resume := agentsession.ResumeCheckpoint{ + WorkspaceKey: workspaceKey, + TranscriptRevision: legacyRevision, + } + if !resumeCheckpointMatchesState(resume, workspaceKey, currentRevision, legacyRevision) { + t.Fatalf("expected legacy len(messages) checkpoint to match during compatibility fallback") + } +} + +func TestResolveResumeWorkspaceKey(t *testing.T) { + t.Parallel() + + sessionWorkdir := t.TempDir() + effectiveWorkdir := t.TempDir() + got := resolveResumeWorkspaceKey(sessionWorkdir, effectiveWorkdir) + if got != agentsession.WorkspacePathKey(sessionWorkdir) { + t.Fatalf("resolveResumeWorkspaceKey(session, effective) = %q, want %q", got, agentsession.WorkspacePathKey(sessionWorkdir)) + } + + gotFallback := resolveResumeWorkspaceKey(" ", effectiveWorkdir) + if gotFallback != agentsession.WorkspacePathKey(effectiveWorkdir) { + t.Fatalf("resolveResumeWorkspaceKey(empty, effective) = %q, want %q", gotFallback, 
agentsession.WorkspacePathKey(effectiveWorkdir)) + } +} + +func TestResumeCheckpointMatchesStateBranches(t *testing.T) { + t.Parallel() + + workspace := agentsession.WorkspacePathKey(t.TempDir()) + if resumeCheckpointMatchesState(agentsession.ResumeCheckpoint{}, workspace, 1, 1) { + t.Fatal("expected empty checkpoint workspace to fail") + } + if resumeCheckpointMatchesState(agentsession.ResumeCheckpoint{WorkspaceKey: workspace}, "", 1, 1) { + t.Fatal("expected empty current workspace to fail") + } + if resumeCheckpointMatchesState(agentsession.ResumeCheckpoint{WorkspaceKey: workspace + "-other", TranscriptRevision: 1}, workspace, 1, 1) { + t.Fatal("expected workspace mismatch to fail") + } + if resumeCheckpointMatchesState(agentsession.ResumeCheckpoint{WorkspaceKey: workspace, TranscriptRevision: -1}, workspace, 1, 1) { + t.Fatal("expected negative checkpoint revision to fail") + } + if resumeCheckpointMatchesState(agentsession.ResumeCheckpoint{WorkspaceKey: workspace, TranscriptRevision: 1}, workspace, -1, 1) { + t.Fatal("expected negative current revision to fail") + } + if !resumeCheckpointMatchesState(agentsession.ResumeCheckpoint{WorkspaceKey: workspace, TranscriptRevision: 9}, workspace, 9, 3) { + t.Fatal("expected exact current revision match to pass") + } + if resumeCheckpointMatchesState(agentsession.ResumeCheckpoint{WorkspaceKey: workspace, TranscriptRevision: 8}, workspace, 9, -1) { + t.Fatal("expected legacy fallback disabled on negative legacy revision") + } + if resumeCheckpointMatchesState(agentsession.ResumeCheckpoint{WorkspaceKey: workspace, TranscriptRevision: 8}, workspace, 9, 7) { + t.Fatal("expected mismatch across current/legacy revisions to fail") + } +} + +func TestSessionTranscriptRevisionIncludesTodoAndPlanState(t *testing.T) { + t.Parallel() + + required := true + base := agentsession.NewWithWorkdir("resume-rich", t.TempDir()) + base.Messages = []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: 
[]providertypes.ContentPart{ + providertypes.NewTextPart("same-count"), + }, + }, + } + base.TodoVersion = 2 + base.Todos = []agentsession.TodoItem{ + { + ID: "todo-1", + Content: "prepare report", + Status: agentsession.TodoStatusInProgress, + Required: &required, + OwnerType: agentsession.TodoOwnerTypeAgent, + OwnerID: "agent-1", + FailureReason: "", + BlockedReason: agentsession.TodoBlockedReasonPermissionWait, + Revision: 3, + UpdatedAt: time.Unix(1700020000, 0).UTC(), + }, + } + base.TaskState = agentsession.TaskState{ + VerificationProfile: agentsession.VerificationProfileFixBug, + Goal: "fix bug", + Progress: []string{" inspect logs ", "write test"}, + OpenItems: []string{" verify patch"}, + NextStep: " run tests ", + Blockers: []string{" permission "}, + KeyArtifacts: []string{" report.md "}, + Decisions: []string{" keep legacy fallback "}, + UserConstraints: []string{" no destructive ops "}, + LastUpdatedAt: time.Unix(1700020001, 0).UTC(), + } + base.CurrentPlan = &agentsession.PlanArtifact{ + ID: "plan-1", + Revision: 2, + } + base.LastFullPlanRevision = 2 + base.PlanApprovalPendingFullAlign = true + + cloned := base + cloned.TaskState.Progress = append([]string(nil), base.TaskState.Progress...) + cloned.TaskState.OpenItems = append([]string(nil), base.TaskState.OpenItems...) + cloned.TaskState.Blockers = append([]string(nil), base.TaskState.Blockers...) + cloned.TaskState.KeyArtifacts = append([]string(nil), base.TaskState.KeyArtifacts...) + cloned.TaskState.Decisions = append([]string(nil), base.TaskState.Decisions...) + cloned.TaskState.UserConstraints = append([]string(nil), base.TaskState.UserConstraints...) + cloned.Todos = append([]agentsession.TodoItem(nil), base.Todos...) 
+ if len(cloned.Todos) > 0 { + requiredCopy := *cloned.Todos[0].Required + cloned.Todos[0].Required = &requiredCopy + } + + before := sessionTranscriptRevision(base) + after := sessionTranscriptRevision(cloned) + if before != after { + t.Fatalf("expected equal revisions for equivalent rich state, got %d vs %d", before, after) + } + + cloned.Todos[0].Status = agentsession.TodoStatusCompleted + if sessionTranscriptRevision(base) == sessionTranscriptRevision(cloned) { + t.Fatal("expected todo status change to affect transcript revision") + } +} + +func TestApplyResumeCheckpointSkipsWorkspaceMismatch(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-resume-workspace-mismatch", fixture.session) + spy := &checkpointStoreSpy{ + latestResume: &agentsession.ResumeCheckpoint{ + RunID: "run-old", + WorkspaceKey: "workspace://stale", + SessionID: fixture.session.ID, + Turn: 1, + Phase: "execute", + TranscriptRevision: sessionTranscriptRevision(fixture.session), + CompletionState: "", + }, + } + service := &Service{ + checkpointStore: spy, + events: make(chan RuntimeEvent, 16), + runtimeSnapshots: make(map[string]RuntimeSnapshot), + } + + service.applyResumeCheckpoint(context.Background(), &state) + + if state.resumeNextBaseLifecycle != "" { + t.Fatalf("resumeNextBaseLifecycle = %q, want empty", state.resumeNextBaseLifecycle) + } + if strings.TrimSpace(state.pendingSystemReminder) != "" { + t.Fatalf("pendingSystemReminder = %q, want empty", state.pendingSystemReminder) + } + if len(collectRuntimeEvents(service.Events())) != 0 { + t.Fatal("expected no runtime events when workspace key mismatches") + } +} + +func TestApplyResumeCheckpointSkipsTranscriptRevisionMismatch(t *testing.T) { + fixture := newRuntimeCheckpointFixture(t) + state := newRunState("run-resume-transcript-mismatch", fixture.session) + spy := &checkpointStoreSpy{ + latestResume: &agentsession.ResumeCheckpoint{ + RunID: "run-old", + WorkspaceKey: 
agentsession.WorkspacePathKey(fixture.session.Workdir), + SessionID: fixture.session.ID, + Turn: 1, + Phase: "execute", + TranscriptRevision: sessionTranscriptRevision(fixture.session) + 1, + CompletionState: "", + }, + } + service := &Service{ + checkpointStore: spy, + events: make(chan RuntimeEvent, 16), + runtimeSnapshots: make(map[string]RuntimeSnapshot), + } + + service.applyResumeCheckpoint(context.Background(), &state) + + if state.resumeNextBaseLifecycle != "" { + t.Fatalf("resumeNextBaseLifecycle = %q, want empty", state.resumeNextBaseLifecycle) + } + if strings.TrimSpace(state.pendingSystemReminder) != "" { + t.Fatalf("pendingSystemReminder = %q, want empty", state.pendingSystemReminder) + } + if len(collectRuntimeEvents(service.Events())) != 0 { + t.Fatal("expected no runtime events when transcript revision mismatches") + } +} + func TestServiceRunResumeVerifyClosureBootstrapsFirstTurn(t *testing.T) { t.Parallel() manager := newRuntimeConfigManager(t) store := newMemoryStore() + seed, err := store.CreateSession(context.Background(), agentsession.CreateSessionInput{ + ID: "session-resume-verify", + Title: "resume verify seed", + Head: agentsession.SessionHead{ + Provider: "openai", + Model: manager.Get().CurrentModel, + Workdir: manager.Get().Workdir, + TaskState: agentsession.TaskState{ + VerificationProfile: agentsession.VerificationProfileTaskOnly, + }, + }, + }) + if err != nil { + t.Fatalf("CreateSession() error = %v", err) + } providerImpl := &scriptedProvider{ responses: []scriptedResponse{ { @@ -486,11 +822,13 @@ func TestServiceRunResumeVerifyClosureBootstrapsFirstTurn(t *testing.T) { } resumeStore := &checkpointStoreSpy{ latestResume: &agentsession.ResumeCheckpoint{ - RunID: "run-old-verify", - SessionID: "ignored-by-spy", - Turn: 2, - Phase: "verify", - CompletionState: "completed", + RunID: "run-old-verify", + WorkspaceKey: agentsession.WorkspacePathKey(seed.Workdir), + SessionID: seed.ID, + Turn: 2, + Phase: "verify", + CompletionState: 
"completed", + TranscriptRevision: sessionTranscriptRevision(seed), }, } service := NewWithFactory( @@ -505,8 +843,9 @@ func TestServiceRunResumeVerifyClosureBootstrapsFirstTurn(t *testing.T) { service.runtimeSnapshots = make(map[string]RuntimeSnapshot) if err := service.Run(context.Background(), UserInput{ - RunID: "run-resume-verify-first-turn", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, + RunID: "run-resume-verify-first-turn", + SessionID: seed.ID, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue")}, }); err != nil { t.Fatalf("Run() error = %v", err) } diff --git a/internal/runtime/checkpoint_resume.go b/internal/runtime/checkpoint_resume.go index 6c440de9a..476f12271 100644 --- a/internal/runtime/checkpoint_resume.go +++ b/internal/runtime/checkpoint_resume.go @@ -2,6 +2,8 @@ package runtime import ( "context" + "encoding/json" + "hash/fnv" "log" "strings" "time" @@ -26,17 +28,19 @@ func (s *Service) updateResumeCheckpoint(ctx context.Context, state *runState, p session := state.session runID := state.runID turn := state.turn + effectiveWorkdir := strings.TrimSpace(state.effectiveWorkdir) state.mu.Unlock() rc := agentsession.ResumeCheckpoint{ - ID: agentsession.NewID("rc"), - WorkspaceKey: agentsession.WorkspacePathKey(session.Workdir), - RunID: runID, - SessionID: session.ID, - Turn: turn, - Phase: phase, - CompletionState: completionState, - UpdatedAt: time.Now(), + ID: agentsession.NewID("rc"), + WorkspaceKey: resolveResumeWorkspaceKey(session.Workdir, effectiveWorkdir), + RunID: runID, + SessionID: session.ID, + Turn: turn, + Phase: phase, + CompletionState: completionState, + TranscriptRevision: sessionTranscriptRevision(session), + UpdatedAt: time.Now(), } if err := s.checkpointStore.SetResumeCheckpoint(ctx, rc); err != nil { @@ -52,6 +56,9 @@ func (s *Service) applyResumeCheckpoint(ctx context.Context, state *runState) { state.mu.Lock() sessionID := strings.TrimSpace(state.session.ID) + 
workspaceKey := resolveResumeWorkspaceKey(state.session.Workdir, state.effectiveWorkdir) + transcriptRevision := sessionTranscriptRevision(state.session) + legacyTranscriptRevision := sessionLegacyTranscriptRevision(state.session) state.mu.Unlock() if sessionID == "" { return @@ -61,6 +68,9 @@ func (s *Service) applyResumeCheckpoint(ctx context.Context, state *runState) { if err != nil || resume == nil { return } + if !resumeCheckpointMatchesState(*resume, workspaceKey, transcriptRevision, legacyTranscriptRevision) { + return + } phase := strings.ToLower(strings.TrimSpace(resume.Phase)) completionState := strings.ToLower(strings.TrimSpace(resume.CompletionState)) @@ -93,6 +103,172 @@ func (s *Service) applyResumeCheckpoint(ctx context.Context, state *runState) { s.emitRuntimeSnapshotUpdated(ctx, state, "resume_applied") } +// sessionTranscriptRevision returns the logical revision of the current session transcript, used by the resume checkpoint consistency check. +func sessionTranscriptRevision(session agentsession.Session) int64 { + currentPlanID := "" + currentPlanRevision := 0 + if session.CurrentPlan != nil { + currentPlanID = strings.TrimSpace(session.CurrentPlan.ID) + currentPlanRevision = session.CurrentPlan.Revision + } + + snapshot := resumeConsistencySnapshot{ + MessageCount: len(session.Messages), + TodoVersion: session.TodoVersion, + Todos: buildResumeTodoFingerprints(session.Todos), + TaskState: buildResumeTaskStateFingerprint(session.TaskState), + CurrentPlanID: currentPlanID, + CurrentPlanRevision: currentPlanRevision, + LastFullPlanRevision: session.LastFullPlanRevision, + PlanApprovalPendingFullAlign: session.PlanApprovalPendingFullAlign, + PlanCompletionPendingFullReview: session.PlanCompletionPendingFullReview, + PlanContextDirty: session.PlanContextDirty, + PlanRestorePendingAlign: session.PlanRestorePendingAlign, + } + // The snapshot fields consist only of basic serializable types, so Marshal always succeeds on this struct and yields stable output. + raw, _ := json.Marshal(snapshot) + hasher := fnv.New64a() + // The FNV hash's Write is an in-memory write that never returns an error, so the error is ignored to keep the implementation simple. + _, _ = hasher.Write(raw) +
const positiveInt64Mask = uint64(1<<63 - 1) + return int64(hasher.Sum64() & positiveInt64Mask) +} + +// sessionLegacyTranscriptRevision returns the pre-upgrade transcript revision semantics: the message count. +func sessionLegacyTranscriptRevision(session agentsession.Session) int64 { + return int64(len(session.Messages)) +} + +// resolveResumeWorkspaceKey computes the workspace comparison key for resume checkpoints in one place, preferring the session workdir and falling back to the runtime effective workdir when it is missing. +func resolveResumeWorkspaceKey(sessionWorkdir string, effectiveWorkdir string) string { + workdir := strings.TrimSpace(sessionWorkdir) + if workdir == "" { + workdir = strings.TrimSpace(effectiveWorkdir) + } + return agentsession.WorkspacePathKey(workdir) +} + +// resumeCheckpointMatchesState checks whether the resume checkpoint still matches the current session's workspace and transcript revision. +func resumeCheckpointMatchesState( + resume agentsession.ResumeCheckpoint, + currentWorkspaceKey string, + currentTranscriptRevision int64, + legacyTranscriptRevision int64, +) bool { + resumeWorkspaceKey := strings.TrimSpace(resume.WorkspaceKey) + workspaceKey := strings.TrimSpace(currentWorkspaceKey) + if resumeWorkspaceKey == "" || workspaceKey == "" { + return false + } + if !strings.EqualFold(resumeWorkspaceKey, workspaceKey) { + return false + } + + if resume.TranscriptRevision < 0 || currentTranscriptRevision < 0 { + return false + } + if resume.TranscriptRevision == currentTranscriptRevision { + return true + } + // Backward compatibility with pre-upgrade checkpoints: the old semantics used len(messages). + if legacyTranscriptRevision < 0 { + return false + } + return resume.TranscriptRevision == legacyTranscriptRevision +} + +// resumeConsistencySnapshot is used to build the resume consistency fingerprint, avoiding false matches that arise from relying on the message count alone. +type resumeConsistencySnapshot struct { + MessageCount int `json:"message_count"` + TodoVersion int `json:"todo_version"` + Todos []resumeTodoFingerprint `json:"todos,omitempty"` + TaskState resumeTaskStateFingerprint `json:"task_state"` + CurrentPlanID string `json:"current_plan_id,omitempty"` + CurrentPlanRevision int `json:"current_plan_revision,omitempty"` + LastFullPlanRevision int
`json:"last_full_plan_revision,omitempty"` + PlanApprovalPendingFullAlign bool `json:"plan_approval_pending_full_align,omitempty"` + PlanCompletionPendingFullReview bool `json:"plan_completion_pending_full_review,omitempty"` + PlanContextDirty bool `json:"plan_context_dirty,omitempty"` + PlanRestorePendingAlign bool `json:"plan_restore_pending_align,omitempty"` +} + +// resumeTodoFingerprint captures the minimal todo state needed for the resume decision. +type resumeTodoFingerprint struct { + ID string `json:"id"` + Status string `json:"status"` + Required bool `json:"required"` + OwnerType string `json:"owner_type,omitempty"` + OwnerID string `json:"owner_id,omitempty"` + FailureReason string `json:"failure_reason,omitempty"` + BlockedReason string `json:"blocked_reason,omitempty"` + Revision int64 `json:"revision"` + UpdatedAtMS int64 `json:"updated_at_ms"` +} + +// resumeTaskStateFingerprint captures the minimal task_state summary needed for the resume decision. +type resumeTaskStateFingerprint struct { + VerificationProfile string `json:"verification_profile,omitempty"` + Goal string `json:"goal,omitempty"` + Progress []string `json:"progress,omitempty"` + OpenItems []string `json:"open_items,omitempty"` + NextStep string `json:"next_step,omitempty"` + Blockers []string `json:"blockers,omitempty"` + KeyArtifacts []string `json:"key_artifacts,omitempty"` + Decisions []string `json:"decisions,omitempty"` + UserConstraints []string `json:"user_constraints,omitempty"` + LastUpdatedAtMS int64 `json:"last_updated_at_ms"` +} + +// buildResumeTodoFingerprints builds the slice of todo fingerprints needed for the resume consistency computation. +func buildResumeTodoFingerprints(items []agentsession.TodoItem) []resumeTodoFingerprint { + if len(items) == 0 { + return nil + } + result := make([]resumeTodoFingerprint, 0, len(items)) + for _, item := range items { + result = append(result, resumeTodoFingerprint{ + ID: strings.TrimSpace(item.ID), + Status: strings.TrimSpace(string(item.Status)), + Required: item.RequiredValue(), + OwnerType: strings.TrimSpace(item.OwnerType), + OwnerID:
strings.TrimSpace(item.OwnerID), + FailureReason: strings.TrimSpace(item.FailureReason), + BlockedReason: strings.TrimSpace(string(item.BlockedReason)), + Revision: item.Revision, + UpdatedAtMS: item.UpdatedAt.UnixMilli(), + }) + } + return result +} + +// buildResumeTaskStateFingerprint builds the task_state fingerprint needed for the resume consistency computation. +func buildResumeTaskStateFingerprint(state agentsession.TaskState) resumeTaskStateFingerprint { + return resumeTaskStateFingerprint{ + VerificationProfile: strings.TrimSpace(string(state.VerificationProfile)), + Goal: strings.TrimSpace(state.Goal), + Progress: cloneTrimmedStringList(state.Progress), + OpenItems: cloneTrimmedStringList(state.OpenItems), + NextStep: strings.TrimSpace(state.NextStep), + Blockers: cloneTrimmedStringList(state.Blockers), + KeyArtifacts: cloneTrimmedStringList(state.KeyArtifacts), + Decisions: cloneTrimmedStringList(state.Decisions), + UserConstraints: cloneTrimmedStringList(state.UserConstraints), + LastUpdatedAtMS: state.LastUpdatedAt.UnixMilli(), + } +} + +// cloneTrimmedStringList copies and trims a string slice so the fingerprint input is stable. +func cloneTrimmedStringList(items []string) []string { + if len(items) == 0 { + return nil + } + result := make([]string, 0, len(items)) + for _, item := range items { + result = append(result, strings.TrimSpace(item)) + } + return result +} + // deriveResumeBaseLifecycle maps the checkpoint phase/completion_state to the run state of the first turn on resume. func deriveResumeBaseLifecycle(phase string, completionState string) controlplane.RunState { switch strings.ToLower(strings.TrimSpace(phase)) { diff --git a/internal/runtime/controlplane/phase.go b/internal/runtime/controlplane/phase.go index 035735da6..ddc9e4b15 100644 --- a/internal/runtime/controlplane/phase.go +++ b/internal/runtime/controlplane/phase.go @@ -46,6 +46,7 @@ var allowedRunStateTransitions = map[RunState]map[RunState]struct{}{ RunStateVerify: { RunStateVerify: {}, RunStatePlan: {}, + RunStateExecute: {}, RunStateCompacting: {}, RunStateWaitingUserQuestion: {},
RunStateWaitingPermission: {}, diff --git a/internal/runtime/controlplane/phase_test.go b/internal/runtime/controlplane/phase_test.go index 376e72d0c..7fc6e25fc 100644 --- a/internal/runtime/controlplane/phase_test.go +++ b/internal/runtime/controlplane/phase_test.go @@ -13,6 +13,7 @@ func TestValidateRunStateTransitionMainlineAndGovernanceStates(t *testing.T) { {from: RunStatePlan, to: RunStateExecute}, {from: RunStateExecute, to: RunStateVerify}, {from: RunStateVerify, to: RunStatePlan}, + {from: RunStateVerify, to: RunStateExecute}, {from: RunStatePlan, to: RunStateCompacting}, {from: RunStateCompacting, to: RunStatePlan}, {from: RunStateExecute, to: RunStateWaitingPermission}, diff --git a/internal/runtime/errors_test.go b/internal/runtime/errors_test.go index db914dc07..24826fbde 100644 --- a/internal/runtime/errors_test.go +++ b/internal/runtime/errors_test.go @@ -40,3 +40,16 @@ func TestHandleRunErrorProviderErrorDoesNotWriteStdLog(t *testing.T) { } } + +func TestIsMaxTurnLimitError(t *testing.T) { + err := newMaxTurnLimitError(40) + if !IsMaxTurnLimitError(err) { + t.Fatal("expected direct max turn error to be recognized") + } + if !IsMaxTurnLimitError(errors.Join(errors.New("outer"), err)) { + t.Fatal("expected joined max turn error to be recognized") + } + if IsMaxTurnLimitError(errors.New("runtime: max turn limit reached (40)")) { + t.Fatal("plain text error should not be treated as max turn error") + } +} diff --git a/internal/runtime/events.go b/internal/runtime/events.go index 5ba148e1f..24c17fb4b 100644 --- a/internal/runtime/events.go +++ b/internal/runtime/events.go @@ -5,6 +5,7 @@ import ( "neo-code/internal/runtime/acceptgate" "neo-code/internal/runtime/controlplane" + agentsession "neo-code/internal/session" ) // EventType identifies the runtime event type. @@ -101,6 +102,12 @@ type AcceptanceDecidedPayload struct { Results []acceptgate.CheckResult `json:"results,omitempty"` } +// PlanUpdatedPayload describes the structured plan snapshot produced or rewritten in plan mode. +type PlanUpdatedPayload struct { +
CurrentPlan *agentsession.PlanArtifact `json:"current_plan"` + DisplayText string `json:"display_text,omitempty"` +} + // LedgerReconciledPayload is a payload reserved for ledger reconciliation. type LedgerReconciledPayload struct { AttemptSeq int `json:"attempt_seq"` @@ -320,6 +327,8 @@ const ( EventThinkingDelta EventType = "thinking_delta" // EventAgentDone indicates that the assistant finished normally. EventAgentDone EventType = "agent_done" + // EventPlanUpdated indicates that the current structured plan was generated or updated. + EventPlanUpdated EventType = "plan_updated" // EventToolStart indicates that a tool started executing. EventToolStart EventType = "tool_start" // EventToolResult indicates that a tool finished executing and its result was written back to the session. diff --git a/internal/runtime/hooks/command_handler.go b/internal/runtime/hooks/command_handler.go new file mode 100644 index 000000000..0523d897c --- /dev/null +++ b/internal/runtime/hooks/command_handler.go @@ -0,0 +1,348 @@ +package hooks + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "runtime" + "strings" +) + +// CommandHookPayloadVersion defines the command hook stdin protocol version; increment it when the stdin structure changes. +const CommandHookPayloadVersion = "1" + +// maxCommandStdoutBytes caps the number of stdout bytes read from an external command to prevent OOM. +const maxCommandStdoutBytes = 1 << 20 // 1 MiB + +// CommandHookPayload is the single-line JSON passed to the external command via stdin. +type CommandHookPayload struct { + PayloadVersion string `json:"payload_version"` + HookID string `json:"hook_id"` + Point string `json:"point"` + RunID string `json:"run_id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Metadata map[string]any `json:"metadata,omitempty"` +} + +// CommandHookResponse is the single-line JSON the external command returns via stdout. +type CommandHookResponse struct { + Status string `json:"status"` + Message string `json:"message,omitempty"` + UpdateInput json.RawMessage `json:"update_input,omitempty"` + Annotations []string `json:"annotations,omitempty"` +} + +// CommandHookSpec describes the execution parameters of a command hook. +type CommandHookSpec struct { + HookID string + Point HookPoint + Command []string // argv mode: [binary, arg1, arg2, ...]
+ Shell bool // true = run via sh -c / powershell -Command + Workdir string +} + +// ValidateCommandParams validates the format of params.command. +// It supports []string / []any (argv mode) and string + shell=true (shell mode). +// This function is the single source of truth for command hook params validation; both the config and runtime packages should call it. +func ValidateCommandParams(params map[string]any) error { + _, _, err := ParseCommandParams(params) + return err +} + +// ParseCommandParams parses params.command into an argv array, supporting three formats: []string, []any, and string+shell. +// It returns the parsed argv, whether shell mode is used, and any validation error. +func ParseCommandParams(params map[string]any) (argv []string, shell bool, err error) { + if len(params) == 0 { + return nil, false, fmt.Errorf("kind command requires params.command") + } + raw, ok := params["command"] + if !ok || raw == nil { + return nil, false, fmt.Errorf("kind command requires params.command") + } + switch v := raw.(type) { + case string: + trimmed := strings.TrimSpace(v) + if trimmed == "" { + return nil, false, fmt.Errorf("kind command requires params.command") + } + shellVal, _ := params["shell"].(bool) + if !shellVal { + return nil, false, fmt.Errorf("string params.command requires params.shell=true; use array format for argv mode") + } + return []string{trimmed}, true, nil + case []string: + if len(v) == 0 { + return nil, false, fmt.Errorf("kind command requires non-empty params.command") + } + out := make([]string, 0, len(v)) + for _, s := range v { + trimmed := strings.TrimSpace(s) + if trimmed == "" { + return nil, false, fmt.Errorf("params.command contains empty element") + } + out = append(out, trimmed) + } + return out, false, nil + case []any: + if len(v) == 0 { + return nil, false, fmt.Errorf("kind command requires non-empty params.command") + } + out := make([]string, 0, len(v)) + for _, item := range v { + s := strings.TrimSpace(fmt.Sprintf("%v", item)) + if s == "" { + return nil, false, fmt.Errorf("params.command contains empty element") + } + out = append(out, s) + } + return out, false, nil + default: + return nil, false,
fmt.Errorf("params.command must be a string (with shell=true) or an array of strings") + } +} + +// BuildCommandPayload builds the stdin JSON payload passed to the external command. +func BuildCommandPayload(hookID string, point HookPoint, input HookContext) CommandHookPayload { + payload := CommandHookPayload{ + PayloadVersion: CommandHookPayloadVersion, + HookID: strings.TrimSpace(hookID), + Point: string(point), + RunID: strings.TrimSpace(input.RunID), + SessionID: strings.TrimSpace(input.SessionID), + } + if len(input.Metadata) > 0 { + payload.Metadata = input.Metadata + } + return payload +} + +// ParseCommandResponse parses the single-line JSON an external command writes to stdout. +// Non-JSON input returns an error, and the caller may fall back to the exit-code compatibility mode. +func ParseCommandResponse(raw []byte) (CommandHookResponse, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return CommandHookResponse{}, fmt.Errorf("empty stdout") + } + var resp CommandHookResponse + if err := json.Unmarshal(trimmed, &resp); err != nil { + return CommandHookResponse{}, fmt.Errorf("invalid JSON: %w", err) + } + normalized := strings.ToLower(strings.TrimSpace(resp.Status)) + switch normalized { + case "pass", "block", "failed": + resp.Status = normalized + default: + return CommandHookResponse{}, fmt.Errorf("invalid status %q", resp.Status) + } + return resp, nil +} + +// RunCommandHook runs the external command and returns a structured HookResult. +// stdout is captured through a pipe and limited to maxCommandStdoutBytes; stderr is captured and attached to the result on failure. +func RunCommandHook(ctx context.Context, spec CommandHookSpec, input HookContext) HookResult { + payload := BuildCommandPayload(spec.HookID, spec.Point, input) + payloadBytes, err := json.Marshal(payload) + if err != nil { + return HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Status: HookResultFailed, + Message: fmt.Sprintf("command hook marshal payload failed: %v", err), + Error: err.Error(), + } + } + payloadBytes = append(payloadBytes, '\n') + + cmd := buildExecCmd(ctx, spec) + cmd.Dir = spec.Workdir + cmd.Env = buildCommandEnv(spec) + cmd.Stdin = bytes.NewReader(payloadBytes) +
+ stdout, stderrBytes, runErr := runAndCapture(cmd) + + // Oversized stdout is treated as an execution failure + if int64(len(stdout)) > maxCommandStdoutBytes { + msg := fmt.Sprintf("command hook stdout exceeded %d byte limit", maxCommandStdoutBytes) + return HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Status: HookResultFailed, + Message: msg, + Error: msg, + } + } + + message := strings.TrimSpace(string(stdout)) + + // A non-zero exit code takes precedence over the JSON status (prevents a malicious script from claiming pass while actually failing) + if runErr != nil { + return buildResultFromExitCode(ctx, spec, runErr, message, stdout, stderrBytes) + } + + // exit code 0: try to parse the stdout JSON protocol + resp, parseErr := ParseCommandResponse(stdout) + if parseErr == nil { + return buildResultFromResponse(spec, resp) + } + + // Fallback mode: exit 0 but stdout is not JSON, treated as pass + return HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Status: HookResultPass, + Message: message, + } +} + +// runAndCapture runs the command, capturing stdout through a pipe (limited to maxCommandStdoutBytes) while also capturing stderr. +func runAndCapture(cmd *exec.Cmd) (stdout, stderr []byte, runErr error) { + cmd.Stderr = &bytes.Buffer{} + + pipe, err := cmd.StdoutPipe() + if err != nil { + return nil, nil, err + } + if err := cmd.Start(); err != nil { + return nil, nil, err + } + + // Limit the amount read so a malicious script cannot cause OOM + limitedReader := io.LimitReader(pipe, maxCommandStdoutBytes+1) + var stdoutBuf bytes.Buffer + _, copyErr := io.Copy(&stdoutBuf, limitedReader) + stdout = stdoutBuf.Bytes() + + waitErr := cmd.Wait() + + if stderrBuf, ok := cmd.Stderr.(*bytes.Buffer); ok { + stderr = stderrBuf.Bytes() + } + + // A pipe read error takes precedence + if copyErr != nil { + return stdout, stderr, fmt.Errorf("reading command stdout: %w", copyErr) + } + + return stdout, stderr, waitErr +} + +func buildExecCmd(ctx context.Context, spec CommandHookSpec) *exec.Cmd { + if spec.Shell { + if len(spec.Command) == 0 { + // Should be unreachable (ParseCommandParams already validates); defensive panic + panic("buildExecCmd: shell mode requires at least one command element") + } + shell := spec.Command[0] + if runtime.GOOS == "windows" { +
return exec.CommandContext(ctx, "powershell", "-Command", shell) + } + return exec.CommandContext(ctx, "sh", "-c", shell) + } + if len(spec.Command) == 0 { + panic("buildExecCmd: command requires at least one element") + } + if len(spec.Command) == 1 { + return exec.CommandContext(ctx, spec.Command[0]) + } + return exec.CommandContext(ctx, spec.Command[0], spec.Command[1:]...) +} + +func buildCommandEnv(spec CommandHookSpec) []string { + env := []string{ + "NEOCODE_HOOK_HOOK_ID=" + spec.HookID, + "NEOCODE_HOOK_POINT=" + string(spec.Point), + "NEOCODE_HOOK_PAYLOAD_VERSION=" + CommandHookPayloadVersion, + } + if runtime.GOOS == "windows" { + for _, key := range []string{"SystemRoot", "SystemDrive", "USERPROFILE"} { + if v := os.Getenv(key); v != "" { + env = append(env, key+"="+v) + } + } + } + return env +} + +func buildResultFromResponse(spec CommandHookSpec, resp CommandHookResponse) HookResult { + result := HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Message: strings.TrimSpace(resp.Message), + } + switch resp.Status { + case "pass": + result.Status = HookResultPass + case "block": + result.Status = HookResultBlock + case "failed": + result.Status = HookResultFailed + if result.Message == "" { + result.Message = "hook returned failed status" + } + result.Error = result.Message + } + if len(resp.Annotations) > 0 { + result.Metadata.Annotations = resp.Annotations + } + if len(resp.UpdateInput) > 0 { + result.Metadata.UpdateInput = resp.UpdateInput + } + return result +} + +func buildResultFromExitCode(ctx context.Context, spec CommandHookSpec, err error, message string, stdout, stderr []byte) HookResult { + result := HookResult{ + HookID: spec.HookID, + Point: spec.Point, + Message: message, + } + // Context cancellation/timeout is classified as failed first + if ctx.Err() != nil { + result.Status = HookResultFailed + if result.Message == "" { + result.Message = fmt.Sprintf("command %v", ctx.Err()) + } + result.Error = ctx.Err().Error() + return result + } + var exitErr *exec.ExitError +
if errors.As(err, &exitErr) { + code := exitErr.ExitCode() + switch code { + case 1, 2: + result.Status = HookResultBlock + result.Error = err.Error() + default: + result.Status = HookResultFailed + if result.Message == "" { + result.Message = fmt.Sprintf("command exited with code %d", code) + } + result.Error = err.Error() + } + } else { + result.Status = HookResultFailed + if result.Message == "" { + result.Message = err.Error() + } + result.Error = err.Error() + } + // Try to extract message/annotations from the stdout JSON (status is still decided by the exit code) + if resp, parseErr := ParseCommandResponse(stdout); parseErr == nil { + if trimmed := strings.TrimSpace(resp.Message); trimmed != "" { + result.Message = trimmed + } + if len(resp.Annotations) > 0 { + result.Metadata.Annotations = resp.Annotations + } + } + // Attach stderr on failure to ease debugging + if stderrText := strings.TrimSpace(string(stderr)); stderrText != "" && result.Message == "" { + result.Message = stderrText + } + return result +} diff --git a/internal/runtime/hooks/command_handler_test.go b/internal/runtime/hooks/command_handler_test.go new file mode 100644 index 000000000..259bc6ba4 --- /dev/null +++ b/internal/runtime/hooks/command_handler_test.go @@ -0,0 +1,1015 @@ +package hooks + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" +) + +func TestBuildCommandPayload(t *testing.T) { + t.Parallel() + payload := BuildCommandPayload("my-hook", HookPointBeforeToolCall, HookContext{ + RunID: "run-123", + SessionID: "sess-456", + Metadata: map[string]any{ + "tool_name": "bash", + "workdir": "/tmp", + }, + }) + if payload.PayloadVersion != CommandHookPayloadVersion { + t.Fatalf("payload_version = %q, want %q", payload.PayloadVersion, CommandHookPayloadVersion) + } + if payload.HookID != "my-hook" { + t.Fatalf("hook_id = %q, want %q", payload.HookID, "my-hook") + } + if payload.Point != string(HookPointBeforeToolCall) { + t.Fatalf("point = %q, want %q", payload.Point,
HookPointBeforeToolCall) + } + if payload.RunID != "run-123" { + t.Fatalf("run_id = %q, want %q", payload.RunID, "run-123") + } + if payload.SessionID != "sess-456" { + t.Fatalf("session_id = %q, want %q", payload.SessionID, "sess-456") + } + if payload.Metadata["tool_name"] != "bash" { + t.Fatalf("metadata[tool_name] = %v, want %q", payload.Metadata["tool_name"], "bash") + } +} + +func TestBuildCommandPayloadEmptyMetadata(t *testing.T) { + t.Parallel() + payload := BuildCommandPayload("hook", HookPointSessionStart, HookContext{}) + if payload.Metadata != nil { + t.Fatalf("metadata should be nil for empty input, got %v", payload.Metadata) + } + if payload.RunID != "" { + t.Fatalf("run_id should be empty, got %q", payload.RunID) + } +} + +func TestParseCommandResponsePass(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"pass","message":"ok"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Status != "pass" { + t.Fatalf("status = %q, want %q", resp.Status, "pass") + } + if resp.Message != "ok" { + t.Fatalf("message = %q, want %q", resp.Message, "ok") + } +} + +func TestParseCommandResponseBlock(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"block","message":"denied"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Status != "block" { + t.Fatalf("status = %q, want %q", resp.Status, "block") + } +} + +func TestParseCommandResponseFailed(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"failed","message":"broken"}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if resp.Status != "failed" { + t.Fatalf("status = %q, want %q", resp.Status, "failed") + } +} + +func TestParseCommandResponseWithAnnotations(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"pass","annotations":["note1","note2"]}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + 
if len(resp.Annotations) != 2 || resp.Annotations[0] != "note1" { + t.Fatalf("annotations = %v, want [note1 note2]", resp.Annotations) + } +} + +func TestParseCommandResponseWithUpdateInput(t *testing.T) { + t.Parallel() + resp, err := ParseCommandResponse([]byte(`{"status":"pass","update_input":{"text":"rewritten"}}`)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(resp.UpdateInput) == 0 { + t.Fatal("update_input should not be empty") + } + var update struct { + Text string `json:"text"` + } + if err := json.Unmarshal(resp.UpdateInput, &update); err != nil { + t.Fatalf("unmarshal update_input: %v", err) + } + if update.Text != "rewritten" { + t.Fatalf("update_input.text = %q, want %q", update.Text, "rewritten") + } +} + +func TestParseCommandResponseInvalidStatus(t *testing.T) { + t.Parallel() + _, err := ParseCommandResponse([]byte(`{"status":"unknown"}`)) + if err == nil { + t.Fatal("expected error for invalid status") + } +} + +func TestParseCommandResponseInvalidJSON(t *testing.T) { + t.Parallel() + _, err := ParseCommandResponse([]byte(`not json`)) + if err == nil { + t.Fatal("expected error for non-JSON input") + } +} + +func TestParseCommandResponseEmptyStdout(t *testing.T) { + t.Parallel() + _, err := ParseCommandResponse([]byte{}) + if err == nil { + t.Fatal("expected error for empty input") + } +} + +func TestRunCommandHookArgvMode(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("argv mode test uses echo which is a shell builtin on Windows") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "test-argv", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"pass","message":"hello from argv"}`}, + Shell: false, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if 
result.Message != "hello from argv" { + t.Fatalf("message = %q, want %q", result.Message, "hello from argv") + } +} + +func TestRunCommandHookArgvModeWindows(t *testing.T) { + t.Parallel() + if runtime.GOOS != "windows" { + t.Skip("Windows-only test") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "test-argv-win", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"pass\",\"message\":\"hello from argv\"}'"}, + Shell: false, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if result.Message != "hello from argv" { + t.Fatalf("message = %q, want %q", result.Message, "hello from argv") + } +} + +func TestRunCommandHookShellMode(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("shell mode test uses sh") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "test-shell", + Point: HookPointBeforeToolCall, + Command: []string{`echo '{"status":"pass","message":"from shell"}'`}, + Shell: true, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if result.Message != "from shell" { + t.Fatalf("message = %q, want %q", result.Message, "from shell") + } +} + +func TestRunCommandHookExitCodeNonZeroEmptyStdout(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "test-exit3", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "exit 3"}, + } + } else { + spec = 
CommandHookSpec{ + HookID: "test-exit3", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "exit 3"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } +} + +func TestRunCommandHookExitCodeBlock(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "test-exit1", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output 'blocked'; exit 1"}, + } + } else { + spec = CommandHookSpec{ + HookID: "test-exit1", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo blocked; exit 1"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultBlock { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultBlock, result.Message) + } + if result.Message != "blocked" { + t.Fatalf("message = %q, want %q", result.Message, "blocked") + } +} + +func TestRunCommandHookTimeout(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "test-timeout", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Start-Sleep -Seconds 10"}, + } + } else { + spec = CommandHookSpec{ + HookID: "test-timeout", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "sleep 10"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } +} + +func TestRunCommandHookEnvIsolation(t *testing.T) { + t.Parallel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = 
CommandHookSpec{ + HookID: "env-test", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "$env:NEOCODE_HOOK_HOOK_ID; $env:NEOCODE_HOOK_POINT; $env:NEOCODE_HOOK_PAYLOAD_VERSION; '{\"status\":\"pass\"}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "env-test", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo $NEOCODE_HOOK_HOOK_ID; echo $NEOCODE_HOOK_POINT; echo $NEOCODE_HOOK_PAYLOAD_VERSION; echo '{\"status\":\"pass\"}'"}, + } + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if !strings.Contains(result.Message, "env-test") { + t.Fatalf("expected NEOCODE_HOOK_HOOK_ID in output, got: %s", result.Message) + } + if !strings.Contains(result.Message, "before_tool_call") { + t.Fatalf("expected NEOCODE_HOOK_POINT in output, got: %s", result.Message) + } + if !strings.Contains(result.Message, CommandHookPayloadVersion) { + t.Fatalf("expected NEOCODE_HOOK_PAYLOAD_VERSION in output, got: %s", result.Message) + } +} + +func TestBuildCommandEnvContainsHookVars(t *testing.T) { + t.Parallel() + spec := CommandHookSpec{HookID: "id-123", Point: HookPointSessionEnd} + env := buildCommandEnv(spec) + envMap := make(map[string]bool) + for _, e := range env { + parts := strings.SplitN(e, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = true + } + } + if !envMap["NEOCODE_HOOK_HOOK_ID"] { + t.Fatal("missing NEOCODE_HOOK_HOOK_ID") + } + if !envMap["NEOCODE_HOOK_POINT"] { + t.Fatal("missing NEOCODE_HOOK_POINT") + } + if !envMap["NEOCODE_HOOK_PAYLOAD_VERSION"] { + t.Fatal("missing NEOCODE_HOOK_PAYLOAD_VERSION") + } +} + +func TestRunCommandHookBackwardCompatPlainText(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var 
spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "compat", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output 'just a message'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "compat", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo just a message; exit 0"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if result.Message != "just a message" { + t.Fatalf("message = %q, want %q", result.Message, "just a message") + } +} + +func TestRunCommandHookAnnotationsPopulated(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "annotated", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"pass\",\"annotations\":[\"a1\",\"a2\"]}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "annotated", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"pass","annotations":["a1","a2"]}`}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if len(result.Metadata.Annotations) != 2 { + t.Fatalf("annotations count = %d, want 2; annotations: %v", len(result.Metadata.Annotations), result.Metadata.Annotations) + } +} + +func TestRunCommandHookWorkdir(t *testing.T) { + t.Parallel() + tmpDir, err := os.MkdirTemp("", "hook-workdir-test") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: 
"workdir-test", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output (Get-Location).Path; exit 0"}, + Workdir: tmpDir, + } + } else { + spec = CommandHookSpec{ + HookID: "workdir-test", + Point: HookPointBeforeToolCall, + Command: []string{"pwd"}, + Workdir: tmpDir, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if !strings.Contains(strings.ToLower(result.Message), strings.ToLower(filepath.Base(tmpDir))) { + t.Fatalf("expected workdir in output, got: %s", result.Message) + } +} + +func TestBuildCommandPayloadRunSessionID(t *testing.T) { + t.Parallel() + payload := BuildCommandPayload("my-hook", HookPointBeforeToolCall, HookContext{ + RunID: "run-abc", + SessionID: "sess-xyz", + }) + if payload.RunID != "run-abc" { + t.Fatalf("run_id = %q, want %q", payload.RunID, "run-abc") + } + if payload.SessionID != "sess-xyz" { + t.Fatalf("session_id = %q, want %q", payload.SessionID, "sess-xyz") + } +} + +func TestBuildCommandPayloadEmptyRunSessionID(t *testing.T) { + t.Parallel() + payload := BuildCommandPayload("hook", HookPointSessionStart, HookContext{}) + if payload.RunID != "" { + t.Fatalf("run_id should be empty, got %q", payload.RunID) + } + if payload.SessionID != "" { + t.Fatalf("session_id should be empty, got %q", payload.SessionID) + } +} + +func TestRunCommandHookExitCodePrecedenceOverJSON(t *testing.T) { + // Security: non-zero exit code must override JSON status. + // A malicious script claiming "pass" while exiting 1 should result in block, not pass. 
+ t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "precedence-test", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"pass\",\"message\":\"claiming pass\"}'; exit 1"}, + } + } else { + spec = CommandHookSpec{ + HookID: "precedence-test", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo '{\"status\":\"pass\",\"message\":\"claiming pass\"}'; exit 1"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultBlock { + t.Fatalf("status = %q, want %q (exit code must take precedence over JSON status)", result.Status, HookResultBlock) + } + // message should still be extracted from JSON stdout + if result.Message != "claiming pass" { + t.Fatalf("message = %q, want %q (should extract message from JSON even when exit code wins)", result.Message, "claiming pass") + } +} + +func TestRunCommandHookExitCodeThreeWithJSONMessage(t *testing.T) { + // exit code 3 + JSON with message → failed status, message from JSON + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "exit3-json", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"pass\",\"message\":\"from json\"}'; exit 3"}, + } + } else { + spec = CommandHookSpec{ + HookID: "exit3-json", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo '{\"status\":\"pass\",\"message\":\"from json\"}'; exit 3"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if result.Message != "from json" { + t.Fatalf("message = %q, want 
%q", result.Message, "from json") + } +} + +func TestRunCommandHookStdinPayload(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "stdin-test", + Point: HookPointUserPromptSubmit, + Command: []string{"powershell", "-Command", "$input"}, + } + } else { + spec = CommandHookSpec{ + HookID: "stdin-test", + Point: HookPointUserPromptSubmit, + Command: []string{"cat"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{ + RunID: "run-789", + SessionID: "sess-012", + Metadata: map[string]any{"workdir": "/tmp"}, + }) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if !strings.Contains(result.Message, CommandHookPayloadVersion) { + t.Fatalf("stdin payload should contain payload_version, got: %s", result.Message) + } + if !strings.Contains(result.Message, "run-789") { + t.Fatalf("stdin payload should contain run_id, got: %s", result.Message) + } + if !strings.Contains(result.Message, "sess-012") { + t.Fatalf("stdin payload should contain session_id, got: %s", result.Message) + } +} + +func TestRunCommandHookShellModeWindows(t *testing.T) { + t.Parallel() + if runtime.GOOS != "windows" { + t.Skip("Windows-only test") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "test-shell-win", + Point: HookPointBeforeToolCall, + Command: []string{`Write-Output '{"status":"pass","message":"from powershell shell"}'`}, + Shell: true, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + if result.Message != "from powershell shell" { + t.Fatalf("message = %q, want %q", result.Message, "from powershell shell") + } +} + +func 
TestRunCommandHookEnvIsolationNoLeak(t *testing.T) { + // Verify that host env vars like PATH, HOME, USER are NOT leaked to the subprocess. + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("PATH leaks at system level on Windows; see buildCommandEnv") + } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "env-no-leak", + Point: HookPointBeforeToolCall, + Command: []string{"env"}, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q; message: %s", result.Status, HookResultPass, result.Message) + } + for _, leaked := range []string{"PATH=", "HOME=", "USER="} { + if strings.Contains(result.Message, leaked) { + t.Fatalf("host env var %q should not be leaked to subprocess, got: %s", leaked, result.Message) + } + } +} + +func TestParseCommandParamsAllBranches(t *testing.T) { + t.Parallel() + + t.Run("string with shell=true", func(t *testing.T) { + t.Parallel() + argv, shell, err := ParseCommandParams(map[string]any{"command": "echo hi", "shell": true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !shell { + t.Fatal("expected shell=true") + } + if len(argv) != 1 || argv[0] != "echo hi" { + t.Fatalf("argv = %v, want [echo hi]", argv) + } + }) + + t.Run("string with whitespace shell=true", func(t *testing.T) { + t.Parallel() + argv, shell, err := ParseCommandParams(map[string]any{"command": " echo hi ", "shell": true}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !shell || argv[0] != "echo hi" { + t.Fatalf("argv = %v, shell = %v", argv, shell) + } + }) + + t.Run("[]string valid", func(t *testing.T) { + t.Parallel() + argv, shell, err := ParseCommandParams(map[string]any{"command": []string{"echo", "hello"}}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if shell { + t.Fatal("expected shell=false for array") + } + if len(argv) != 2 || argv[0] != "echo" || 
argv[1] != "hello" { + t.Fatalf("argv = %v", argv) + } + }) + + t.Run("[]string empty", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": []string{}}) + if err == nil { + t.Fatal("expected error for empty []string") + } + }) + + t.Run("[]string with empty element", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": []string{"echo", " ", "ok"}}) + if err == nil { + t.Fatal("expected error for empty element in []string") + } + }) + + t.Run("[]any with empty element after Sprintf", func(t *testing.T) { + t.Parallel() + // nil element => fmt.Sprintf("%v", nil) => "<nil>" which is non-empty + // but empty string element => fmt.Sprintf("%v", "") => "" which is empty + _, _, err := ParseCommandParams(map[string]any{"command": []any{"echo", ""}}) + if err == nil { + t.Fatal("expected error for empty element in []any") + } + }) + + t.Run("unsupported type", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": 123}) + if err == nil { + t.Fatal("expected error for unsupported type") + } + }) + + t.Run("nil command value", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": nil}) + if err == nil { + t.Fatal("expected error for nil command") + } + }) + + t.Run("shell=false on string", func(t *testing.T) { + t.Parallel() + _, _, err := ParseCommandParams(map[string]any{"command": "echo ok", "shell": false}) + if err == nil { + t.Fatal("expected error for string without shell=true") + } + }) +} + +func TestRunCommandHookStdoutTooLarge(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + // Generate output slightly above the 1MiB limit + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "stdout-toolarge", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output 
('x' * 1048577)"}, + } + } else { + spec = CommandHookSpec{ + HookID: "stdout-toolarge", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "printf '%1048577s' ''"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if !strings.Contains(result.Message, "byte limit") { + t.Fatalf("message should mention byte limit, got: %s", result.Message) + } +} + +func TestRunCommandHookStdinPayloadWithMetadata(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "stdin-meta", + Point: HookPointUserPromptSubmit, + Command: []string{"powershell", "-Command", "$input"}, + } + } else { + spec = CommandHookSpec{ + HookID: "stdin-meta", + Point: HookPointUserPromptSubmit, + Command: []string{"cat"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{ + RunID: "run-meta", + SessionID: "sess-meta", + Metadata: map[string]any{"tool_name": "bash", "workdir": "/tmp"}, + }) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if !strings.Contains(result.Message, `"tool_name"`) { + t.Fatalf("stdin should contain tool_name metadata, got: %s", result.Message) + } +} + +func TestRunCommandHookExitCodeTwoBlocks(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "exit2", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "exit 2"}, + } + } else { + spec = CommandHookSpec{ + HookID: "exit2", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "exit 2"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if 
result.Status != HookResultBlock { + t.Fatalf("status = %q, want %q", result.Status, HookResultBlock) + } + if result.Error == "" { + t.Fatal("expected Error to be set for exit code 2 block") + } +} + +func TestRunCommandHookExitCodeZeroEmptyStdout(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "exit0-empty", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", ""}, + } + } else { + spec = CommandHookSpec{ + HookID: "exit0-empty", + Point: HookPointBeforeToolCall, + Command: []string{"true"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } +} + +func TestRunCommandHookNonExistentBinary(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + spec := CommandHookSpec{ + HookID: "no-such-binary", + Point: HookPointBeforeToolCall, + Command: []string{"nonexistent_binary_xyz_12345"}, + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if result.Error == "" { + t.Fatal("expected Error to be set for nonexistent binary") + } +} + +func TestRunCommandHookBlockWithMessage(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "block-msg", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"block\",\"message\":\"not allowed\"}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "block-msg", + Point: HookPointBeforeToolCall, + Command: []string{"echo", 
`{"status":"block","message":"not allowed"}`}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultBlock { + t.Fatalf("status = %q, want %q", result.Status, HookResultBlock) + } + if result.Message != "not allowed" { + t.Fatalf("message = %q, want %q", result.Message, "not allowed") + } +} + +func TestRunCommandHookFailedStatusWithDefaultMessage(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "failed-default", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"failed\"}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "failed-default", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"failed"}`}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if result.Message != "hook returned failed status" { + t.Fatalf("message = %q, want default failed message", result.Message) + } + if result.Error != "hook returned failed status" { + t.Fatalf("error = %q, want default failed message", result.Error) + } +} + +func TestRunCommandHookFailedStatusWithCustomMessage(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "failed-custom", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Output '{\"status\":\"failed\",\"message\":\"custom error\"}'"}, + } + } else { + spec = CommandHookSpec{ + HookID: "failed-custom", + Point: HookPointBeforeToolCall, + Command: []string{"echo", `{"status":"failed","message":"custom error"}`}, + } + } + result := RunCommandHook(ctx, 
spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } + if result.Message != "custom error" { + t.Fatalf("message = %q, want %q", result.Message, "custom error") + } + if result.Error != "custom error" { + t.Fatalf("error = %q, want %q", result.Error, "custom error") + } +} + +func TestRunCommandHookPassWithAnnotationsAndUpdateInput(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + payload := `{"status":"pass","message":"ok","annotations":["a1","a2"],"update_input":{"text":"rewritten"}}` + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "full-output", + Point: HookPointUserPromptSubmit, + Command: []string{"powershell", "-Command", fmt.Sprintf("Write-Output '%s'", payload)}, + } + } else { + spec = CommandHookSpec{ + HookID: "full-output", + Point: HookPointUserPromptSubmit, + Command: []string{"echo", payload}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultPass { + t.Fatalf("status = %q, want %q", result.Status, HookResultPass) + } + if result.Message != "ok" { + t.Fatalf("message = %q, want %q", result.Message, "ok") + } + if len(result.Metadata.Annotations) != 2 || result.Metadata.Annotations[0] != "a1" { + t.Fatalf("annotations = %v, want [a1 a2]", result.Metadata.Annotations) + } + if len(result.Metadata.UpdateInput) == 0 { + t.Fatal("expected UpdateInput to be populated") + } +} + +func TestRunCommandHookExitCodeThreeWithStderr(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + var spec CommandHookSpec + if runtime.GOOS == "windows" { + spec = CommandHookSpec{ + HookID: "exit3-stderr", + Point: HookPointBeforeToolCall, + Command: []string{"powershell", "-Command", "Write-Error 'bad thing'; exit 3"}, + } + } else { + spec = CommandHookSpec{ + 
HookID: "exit3-stderr", + Point: HookPointBeforeToolCall, + Command: []string{"sh", "-c", "echo bad thing >&2; exit 3"}, + } + } + result := RunCommandHook(ctx, spec, HookContext{}) + if result.Status != HookResultFailed { + t.Fatalf("status = %q, want %q", result.Status, HookResultFailed) + } +} + +func TestBuildCommandEnvContainsNEOCODEVars(t *testing.T) { + t.Parallel() + spec := CommandHookSpec{HookID: "id-env", Point: HookPointSessionStart} + env := buildCommandEnv(spec) + envMap := make(map[string]bool) + for _, e := range env { + parts := strings.SplitN(e, "=", 2) + if len(parts) == 2 { + envMap[parts[0]] = true + } + } + for _, key := range []string{"NEOCODE_HOOK_HOOK_ID", "NEOCODE_HOOK_POINT", "NEOCODE_HOOK_PAYLOAD_VERSION"} { + if !envMap[key] { + t.Fatalf("missing %s in env", key) + } + } + if runtime.GOOS == "windows" { + for _, key := range []string{"SystemRoot", "SystemDrive", "USERPROFILE"} { + if os.Getenv(key) != "" && !envMap[key] { + t.Fatalf("missing Windows env var %s", key) + } + } + } +} + +func TestValidateCommandParams(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + params map[string]any + wantErr bool + }{ + {"nil params", nil, true}, + {"empty params", map[string]any{}, true}, + {"missing command", map[string]any{"other": "val"}, true}, + {"empty string command", map[string]any{"command": ""}, true}, + {"string without shell", map[string]any{"command": "echo ok"}, true}, + {"string with shell", map[string]any{"command": "echo ok", "shell": true}, false}, + {"empty array", map[string]any{"command": []any{}}, true}, + {"valid array", map[string]any{"command": []any{"echo", "ok"}}, false}, + {"array with empty element", map[string]any{"command": []any{"echo", ""}}, true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + err := ValidateCommandParams(tc.params) + if (err != nil) != tc.wantErr { + t.Fatalf("ValidateCommandParams() error = %v, wantErr %v", err, tc.wantErr) + } + }) + 
} +} diff --git a/internal/runtime/hooks/executor.go b/internal/runtime/hooks/executor.go index 8f32faef9..2586483a3 100644 --- a/internal/runtime/hooks/executor.go +++ b/internal/runtime/hooks/executor.go @@ -79,6 +79,9 @@ func (e *Executor) Run(ctx context.Context, point HookPoint, input HookContext) if spec.Scope == HookScopeUser || spec.Scope == HookScopeRepo { hookInput = sanitizeUserHookContext(hookInput) } + if spec.Matcher != nil && !spec.Matcher.Match(hookInput) { + continue + } if spec.Mode == HookModeAsync || spec.Mode == HookModeAsyncRewake { e.runAsync(ctx, spec, hookInput) continue @@ -340,6 +343,7 @@ func sanitizeUserHookContext(input HookContext) HookContext { "point": {}, "tool_call_id": {}, "tool_name": {}, + "tool_arguments_preview": {}, "is_error": {}, "error_class": {}, "result_content_preview": {}, diff --git a/internal/runtime/hooks/executor_test.go b/internal/runtime/hooks/executor_test.go index 6a1372e4a..a3d52438d 100644 --- a/internal/runtime/hooks/executor_test.go +++ b/internal/runtime/hooks/executor_test.go @@ -955,10 +955,11 @@ func TestExecutorSanitizeUserHookContext(t *testing.T) { RunID: "run-1", SessionID: "session-1", Metadata: map[string]any{ - "tool_name": "bash", - "tool_arguments": "--secret-token=abc", - "capability_token": "should-not-leak", - "workdir": "/tmp/work", + "tool_name": "bash", + "tool_arguments": "--secret-token=abc", + "tool_arguments_preview": "token=***", + "capability_token": "should-not-leak", + "workdir": "/tmp/work", }, }) @@ -971,6 +972,9 @@ func TestExecutorSanitizeUserHookContext(t *testing.T) { if _, exists := captured.Metadata["tool_arguments"]; exists { t.Fatal("tool_arguments should be stripped for user hook context") } + if got := captured.Metadata["tool_arguments_preview"]; got != "token=***" { + t.Fatalf("tool_arguments_preview = %v, want token=***", got) + } if _, exists := captured.Metadata["capability_token"]; exists { t.Fatal("capability_token should be stripped for user hook context") } @@ 
-999,10 +1003,11 @@ func TestExecutorSanitizeRepoHookContext(t *testing.T) { RunID: "run-1", SessionID: "session-1", Metadata: map[string]any{ - "tool_name": "bash", - "tool_arguments": "--secret-token=abc", - "capability_token": "should-not-leak", - "workdir": "/tmp/work", + "tool_name": "bash", + "tool_arguments": "--secret-token=abc", + "tool_arguments_preview": "token=***", + "capability_token": "should-not-leak", + "workdir": "/tmp/work", }, }) @@ -1012,7 +1017,38 @@ func TestExecutorSanitizeRepoHookContext(t *testing.T) { if _, exists := captured.Metadata["tool_arguments"]; exists { t.Fatal("tool_arguments should be stripped for repo hook context") } + if got := captured.Metadata["tool_arguments_preview"]; got != "token=***" { + t.Fatalf("tool_arguments_preview = %v, want token=***", got) + } if _, exists := captured.Metadata["capability_token"]; exists { t.Fatal("capability_token should be stripped for repo hook context") } } + +func TestExecutorSkipsHookWhenMatcherMissed(t *testing.T) { + t.Parallel() + + registry := NewRegistry() + executor := NewExecutor(registry, nil, 100*time.Millisecond) + if err := registry.Register(HookSpec{ + ID: "matcher-hook", + Point: HookPointBeforeToolCall, + Scope: HookScopeUser, + Matcher: &HookMatcher{ToolNames: []string{"bash"}}, + Handler: func(context.Context, HookContext) HookResult { + return HookResult{Status: HookResultPass, Message: "should-not-run"} + }, + }); err != nil { + t.Fatalf("Register() error = %v", err) + } + + output := executor.Run(context.Background(), HookPointBeforeToolCall, HookContext{ + Metadata: map[string]any{"tool_name": "filesystem"}, + }) + if output.Blocked { + t.Fatalf("Blocked = true, want false") + } + if len(output.Results) != 0 { + t.Fatalf("len(Results) = %d, want 0 when matcher missed", len(output.Results)) + } +} diff --git a/internal/runtime/hooks/matcher.go b/internal/runtime/hooks/matcher.go new file mode 100644 index 000000000..16c133228 --- /dev/null +++ 
b/internal/runtime/hooks/matcher.go @@ -0,0 +1,281 @@ +package hooks + +import ( + "fmt" + "regexp" + "strings" +) + +const ( + // MaxHookMatcherRegexLength caps the length of a single tool_name_regex expression so overlong input cannot slow down matching. + MaxHookMatcherRegexLength = 256 +) + +const ( + hookMatcherFieldToolName = "tool_name" + hookMatcherFieldToolNameRegex = "tool_name_regex" + hookMatcherFieldArgumentsContains = "arguments_contains" + hookMatcherMetadataToolName = "tool_name" + hookMatcherMetadataArguments = "tool_arguments_preview" +) + +// HookMatcher describes a compiled hook matcher. +type HookMatcher struct { + ToolNames []string + ToolNameRegex []*regexp.Regexp + ArgumentsContains []string +} + +// HasHookMatcherConfig reports whether the matcher config contains at least one non-empty dimension. +func HasHookMatcherConfig(raw map[string]any) bool { + if len(raw) == 0 { + return false + } + names := readHookMatcherStringValues(raw, hookMatcherFieldToolName) + if len(names) > 0 { + return true + } + regexes := readHookMatcherStringValues(raw, hookMatcherFieldToolNameRegex) + if len(regexes) > 0 { + return true + } + contains := readHookMatcherStringValues(raw, hookMatcherFieldArgumentsContains) + return len(contains) > 0 +} + +// ValidateHookMatcher checks whether the matcher config is valid for the given hook point. +func ValidateHookMatcher(point HookPoint, raw map[string]any) error { + _, err := CompileHookMatcher(point, raw) + return err +} + +// CompileHookMatcher compiles the raw matcher config into an executable structure and fails fast if the hook point does not support a field. +func CompileHookMatcher(point HookPoint, raw map[string]any) (*HookMatcher, error) { + if len(raw) == 0 { + return nil, nil + } + if err := validateHookMatcherFields(raw); err != nil { + return nil, err + } + if !HasHookMatcherConfig(raw) { + return nil, fmt.Errorf("match contains no recognized matcher fields (expected: tool_name, tool_name_regex, arguments_contains)") + } + capability, ok := HookPointCapabilities(point) + if !ok { + return nil, fmt.Errorf("point %q is not supported", point) + } + + namesRaw := readHookMatcherStringValues(raw, hookMatcherFieldToolName) + regexRaw := 
readHookMatcherStringValues(raw, hookMatcherFieldToolNameRegex) + containsRaw := readHookMatcherStringValues(raw, hookMatcherFieldArgumentsContains) + + if len(namesRaw) > 0 && !capability.Matcher.ToolName { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldToolName) + } + if len(regexRaw) > 0 && !capability.Matcher.ToolNameRegex { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldToolNameRegex) + } + if len(containsRaw) > 0 && !capability.Matcher.ArgumentsContains { + return nil, fmt.Errorf("point %q does not support matcher field %q", point, hookMatcherFieldArgumentsContains) + } + + matcher := &HookMatcher{ + ToolNames: normalizeHookMatcherValues(namesRaw), + ArgumentsContains: normalizeHookMatcherValues(containsRaw), + } + for _, expression := range regexRaw { + trimmed := strings.TrimSpace(expression) + if trimmed == "" { + continue + } + if len(trimmed) > MaxHookMatcherRegexLength { + return nil, fmt.Errorf( + "matcher field %q expression length exceeds %d", + hookMatcherFieldToolNameRegex, + MaxHookMatcherRegexLength, + ) + } + compiled, err := regexp.Compile(trimmed) + if err != nil { + return nil, fmt.Errorf("matcher field %q has invalid regex %q: %w", hookMatcherFieldToolNameRegex, trimmed, err) + } + matcher.ToolNameRegex = append(matcher.ToolNameRegex, compiled) + } + if matcher.IsEmpty() { + return nil, fmt.Errorf("match must include at least one non-empty matcher field") + } + return matcher, nil +} + +// validateHookMatcherFields rejects unsupported fields in the matcher config so that typos are not silently ignored. +func validateHookMatcherFields(raw map[string]any) error { + if len(raw) == 0 { + return nil + } + for key := range raw { + normalized := strings.ToLower(strings.TrimSpace(key)) + switch normalized { + case hookMatcherFieldToolName, hookMatcherFieldToolNameRegex, hookMatcherFieldArgumentsContains: + continue + default: + return fmt.Errorf( + "match contains unknown field %q (allowed: tool_name, 
tool_name_regex, arguments_contains)", + key, + ) + } + } + return nil +} + +// IsEmpty reports whether the matcher has no executable dimensions. +func (m *HookMatcher) IsEmpty() bool { + if m == nil { + return true + } + return len(m.ToolNames) == 0 && len(m.ToolNameRegex) == 0 && len(m.ArgumentsContains) == 0 +} + +// Match evaluates the matcher against a HookContext; fields are combined with AND, multiple values within one field with OR. +func (m *HookMatcher) Match(input HookContext) bool { + if m == nil || m.IsEmpty() { + return true + } + toolName := strings.TrimSpace(readHookMatcherMetadataString(input.Metadata, hookMatcherMetadataToolName)) + if len(m.ToolNames) > 0 { + if toolName == "" || !containsEqualFold(m.ToolNames, toolName) { + return false + } + } + if len(m.ToolNameRegex) > 0 { + if toolName == "" { + return false + } + matched := false + for _, compiled := range m.ToolNameRegex { + if compiled.MatchString(toolName) { + matched = true + break + } + } + if !matched { + return false + } + } + if len(m.ArgumentsContains) > 0 { + argumentsPreview := strings.ToLower(strings.TrimSpace(readHookMatcherMetadataString( + input.Metadata, + hookMatcherMetadataArguments, + ))) + if argumentsPreview == "" { + return false + } + matched := false + for _, fragment := range m.ArgumentsContains { + if strings.Contains(argumentsPreview, fragment) { + matched = true + break + } + } + if !matched { + return false + } + } + return true +} + +// readHookMatcherStringValues reads the string values of a matcher field, accepting string / []string / []any. +func readHookMatcherStringValues(raw map[string]any, key string) []string { + if len(raw) == 0 { + return nil + } + value, ok := raw[key] + if !ok || value == nil { + return nil + } + switch typed := value.(type) { + case string: + if strings.TrimSpace(typed) == "" { + return nil + } + return []string{typed} + case []string: + out := make([]string, 0, len(typed)) + for _, item := range typed { + if strings.TrimSpace(item) == "" { + continue + } + out = append(out, item) + } + return out + case []any: + out := make([]string, 0, len(typed)) + for _, item 
:= range typed { + if item == nil { + continue + } + text := strings.TrimSpace(fmt.Sprintf("%v", item)) + if text == "" { + continue + } + out = append(out, text) + } + return out + default: + text := strings.TrimSpace(fmt.Sprintf("%v", typed)) + if text == "" { + return nil + } + return []string{text} + } +} + +// normalizeHookMatcherValues 将 matcher 词条规范为小写并剔除空值。 +func normalizeHookMatcherValues(values []string) []string { + if len(values) == 0 { + return nil + } + normalized := make([]string, 0, len(values)) + for _, value := range values { + text := strings.ToLower(strings.TrimSpace(value)) + if text == "" { + continue + } + normalized = append(normalized, text) + } + return normalized +} + +// containsEqualFold 判断字符串列表是否包含目标值(忽略大小写)。 +func containsEqualFold(values []string, target string) bool { + normalizedTarget := strings.ToLower(strings.TrimSpace(target)) + if normalizedTarget == "" { + return false + } + for _, value := range values { + if strings.EqualFold(strings.TrimSpace(value), normalizedTarget) { + return true + } + } + return false +} + +// readHookMatcherMetadataString 从 metadata 中读取字符串,兼容大小写键和非字符串值。 +func readHookMatcherMetadataString(metadata map[string]any, key string) string { + if len(metadata) == 0 { + return "" + } + normalizedKey := strings.ToLower(strings.TrimSpace(key)) + if normalizedKey == "" { + return "" + } + if value, ok := metadata[normalizedKey]; ok && value != nil { + return strings.TrimSpace(fmt.Sprintf("%v", value)) + } + for currentKey, value := range metadata { + if !strings.EqualFold(strings.TrimSpace(currentKey), normalizedKey) || value == nil { + continue + } + return strings.TrimSpace(fmt.Sprintf("%v", value)) + } + return "" +} diff --git a/internal/runtime/hooks/matcher_test.go b/internal/runtime/hooks/matcher_test.go new file mode 100644 index 000000000..cf85dfd3f --- /dev/null +++ b/internal/runtime/hooks/matcher_test.go @@ -0,0 +1,344 @@ +package hooks + +import ( + "regexp" + "testing" +) + +func 
TestHasHookMatcherConfig(t *testing.T) { + t.Parallel() + + if HasHookMatcherConfig(nil) { + t.Fatal("nil matcher config should be false") + } + if HasHookMatcherConfig(map[string]any{}) { + t.Fatal("empty matcher config should be false") + } + if !HasHookMatcherConfig(map[string]any{"tool_name": "bash"}) { + t.Fatal("tool_name matcher should be true") + } + if !HasHookMatcherConfig(map[string]any{"tool_name_regex": []any{"^bash$"}}) { + t.Fatal("tool_name_regex matcher should be true") + } + if !HasHookMatcherConfig(map[string]any{"arguments_contains": []string{"rm -rf"}}) { + t.Fatal("arguments_contains matcher should be true") + } + if HasHookMatcherConfig(map[string]any{"tool_name": " "}) { + t.Fatal("whitespace-only tool_name should be false") + } +} + +func TestCompileHookMatcherAndMatch(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": []any{"bash", "filesystem"}, + "tool_name_regex": []string{`^(bash|shell)$`}, + "arguments_contains": []string{"rm -rf"}, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher to be compiled") + } + + if !matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "bash", + "tool_arguments_preview": "sudo rm -rf /tmp/test", + }, + }) { + t.Fatal("expected matcher to pass for matching metadata") + } + if matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "bash", + "tool_arguments_preview": "echo hello", + }, + }) { + t.Fatal("expected matcher to fail when arguments_contains not matched") + } + if matcher.Match(HookContext{ + Metadata: map[string]any{ + "tool_name": "filesystem", + "tool_arguments_preview": "rm -rf /tmp", + }, + }) { + t.Fatal("expected matcher to fail when tool_name_regex not matched") + } +} + +func TestCompileHookMatcherValidation(t *testing.T) { + t.Parallel() + + if _, err := CompileHookMatcher(HookPointSessionStart, 
map[string]any{ + "tool_name": "bash", + }); err == nil { + t.Fatal("expected session_start tool_name matcher to be rejected") + } + + if _, err := CompileHookMatcher(HookPointAfterToolResult, map[string]any{ + "arguments_contains": []string{"rm -rf"}, + }); err == nil { + t.Fatal("expected after_tool_result arguments_contains to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": "(", + }); err == nil { + t.Fatal("expected invalid regex to fail") + } + + longRegex := make([]byte, MaxHookMatcherRegexLength+1) + for i := range longRegex { + longRegex[i] = 'a' + } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": string(longRegex), + }); err == nil { + t.Fatal("expected overlong regex to fail") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_names": "bash", + }); err == nil { + t.Fatal("expected unknown matcher field to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "unknown": "value", + }); err == nil { + t.Fatal("expected completely unknown matcher field to be rejected") + } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + "tool_names": []any{"filesystem"}, + }); err == nil { + t.Fatal("expected mixed matcher fields with typo to be rejected") + } + + if _, err := CompileHookMatcher(HookPointBeforeToolCall, nil); err != nil { + t.Fatal("nil raw should succeed with nil matcher") + } + if _, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{}); err != nil { + t.Fatal("empty raw should succeed with nil matcher") + } +} + +func TestValidateHookMatcher(t *testing.T) { + t.Parallel() + + if err := ValidateHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + }); err != nil { + t.Fatalf("ValidateHookMatcher() error = %v", err) + } + if err := ValidateHookMatcher(HookPointSessionStart, 
map[string]any{ + "tool_name": "bash", + }); err == nil { + t.Fatal("expected session_start matcher to fail validation") + } +} + +func TestIsEmpty(t *testing.T) { + t.Parallel() + + var nilMatcher *HookMatcher + if !nilMatcher.IsEmpty() { + t.Fatal("nil matcher should be empty") + } + if !(&HookMatcher{}).IsEmpty() { + t.Fatal("zero-value matcher should be empty") + } + if (&HookMatcher{ToolNames: []string{"bash"}}).IsEmpty() { + t.Fatal("matcher with tool_name should not be empty") + } +} + +func TestMatchNilAndEmpty(t *testing.T) { + t.Parallel() + + var nilMatcher *HookMatcher + if !nilMatcher.Match(HookContext{}) { + t.Fatal("nil matcher should match everything") + } + empty := &HookMatcher{} + if !empty.Match(HookContext{}) { + t.Fatal("empty matcher should match everything") + } +} + +func TestMatchSingleDimension(t *testing.T) { + t.Parallel() + + t.Run("tool_name only", func(t *testing.T) { + t.Parallel() + m := &HookMatcher{ToolNames: []string{"bash", "filesystem"}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { + t.Fatal("expected match for bash") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_name": "python"}}) { + t.Fatal("expected no match for python") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when tool_name metadata missing") + } + }) + + t.Run("tool_name_regex only", func(t *testing.T) { + t.Parallel() + compiled := regexp.MustCompile(`^(bash|shell)$`) + m := &HookMatcher{ToolNameRegex: []*regexp.Regexp{compiled}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_name": "bash"}}) { + t.Fatal("expected regex match for bash") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_name": "python"}}) { + t.Fatal("expected regex no match for python") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when tool_name missing for regex") + } + }) + + t.Run("arguments_contains only", func(t *testing.T) { + t.Parallel() + 
m := &HookMatcher{ArgumentsContains: []string{"rm -rf", "sudo"}} + if !m.Match(HookContext{Metadata: map[string]any{"tool_arguments_preview": "sudo rm -rf /tmp"}}) { + t.Fatal("expected arguments_contains match") + } + if m.Match(HookContext{Metadata: map[string]any{"tool_arguments_preview": "echo hello"}}) { + t.Fatal("expected arguments_contains no match") + } + if m.Match(HookContext{Metadata: map[string]any{}}) { + t.Fatal("expected no match when arguments_preview missing") + } + }) +} + +func TestReadHookMatcherStringValues(t *testing.T) { + t.Parallel() + + if got := readHookMatcherStringValues(nil, "x"); len(got) != 0 { + t.Fatal("nil raw should return nil") + } + if got := readHookMatcherStringValues(map[string]any{}, "x"); len(got) != 0 { + t.Fatal("empty raw should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": nil}, "x"); len(got) != 0 { + t.Fatal("nil value should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": " "}, "x"); len(got) != 0 { + t.Fatal("whitespace-only string should return nil") + } + if got := readHookMatcherStringValues(map[string]any{"x": 42}, "x"); len(got) != 1 || got[0] != "42" { + t.Fatalf("int value should be converted to string, got %v", got) + } + if got := readHookMatcherStringValues(map[string]any{"x": []any{" a ", nil, 123}}, "x"); len(got) != 2 || got[0] != "a" || got[1] != "123" { + t.Fatalf("[]any with mixed values, got %v", got) + } + if got := readHookMatcherStringValues(map[string]any{"x": "hello"}, "y"); len(got) != 0 { + t.Fatal("missing key should return nil") + } +} + +func TestNormalizeHookMatcherValues(t *testing.T) { + t.Parallel() + + if got := normalizeHookMatcherValues(nil); len(got) != 0 { + t.Fatal("nil values should return nil") + } + if got := normalizeHookMatcherValues([]string{}); len(got) != 0 { + t.Fatal("empty values should return nil") + } + if got := normalizeHookMatcherValues([]string{" ", "\t"}); len(got) != 0 { + t.Fatal("whitespace-only 
values should return empty") + } + if got := normalizeHookMatcherValues([]string{" BASH ", "", " Filesystem "}); len(got) != 2 || got[0] != "bash" || got[1] != "filesystem" { + t.Fatalf("mixed values should be normalized, got %v", got) + } +} + +func TestContainsEqualFold(t *testing.T) { + t.Parallel() + + if containsEqualFold(nil, "bash") { + t.Fatal("nil values should not match") + } + if containsEqualFold([]string{"bash"}, "") { + t.Fatal("empty target should not match") + } + if containsEqualFold([]string{"bash"}, " ") { + t.Fatal("whitespace-only target should not match") + } + if !containsEqualFold([]string{"BASH", "FILESYSTEM"}, " bash ") { + t.Fatal("case-insensitive match should work") + } + if containsEqualFold([]string{"bash"}, "python") { + t.Fatal("non-matching should return false") + } +} + +func TestReadHookMatcherMetadataString(t *testing.T) { + t.Parallel() + + if got := readHookMatcherMetadataString(nil, "x"); got != "" { + t.Fatal("nil metadata should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{}, "x"); got != "" { + t.Fatal("empty metadata should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": nil}, "x"); got != "" { + t.Fatal("nil value should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": 123}, "x"); got != "123" { + t.Fatalf("non-string value should be converted, got %q", got) + } + if got := readHookMatcherMetadataString(map[string]any{"x": "hello"}, " "); got != "" { + t.Fatal("empty key should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"x": "hello"}, ""); got != "" { + t.Fatal("empty string key should return empty") + } + if got := readHookMatcherMetadataString(map[string]any{"TOOL_NAME": "bash"}, "tool_name"); got != "bash" { + t.Fatalf("case-insensitive key lookup failed, got %q", got) + } + if got := readHookMatcherMetadataString(map[string]any{"y": "hello"}, "x"); got != "" { + t.Fatal("missing key should 
return empty") + } +} + +func TestCompileHookMatcherRegexWhitespaceSkipped(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name": "bash", + "tool_name_regex": []string{" ", "\t"}, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher compiled even when regex values are whitespace-only") + } + if len(matcher.ToolNameRegex) != 0 { + t.Fatalf("expected empty tool_name_regex slice, got %d entries", len(matcher.ToolNameRegex)) + } +} + +func TestCompileHookMatcherRegexOnly(t *testing.T) { + t.Parallel() + + matcher, err := CompileHookMatcher(HookPointBeforeToolCall, map[string]any{ + "tool_name_regex": `^bash`, + }) + if err != nil { + t.Fatalf("CompileHookMatcher() error = %v", err) + } + if matcher == nil { + t.Fatal("expected matcher compiled for regex-only config") + } + if !matcher.Match(HookContext{Metadata: map[string]any{"tool_name": "bash-script"}}) { + t.Fatal("expected regex to match") + } +} diff --git a/internal/runtime/hooks/result.go b/internal/runtime/hooks/result.go index e224182db..697f47bef 100644 --- a/internal/runtime/hooks/result.go +++ b/internal/runtime/hooks/result.go @@ -1,6 +1,9 @@ package hooks -import "time" +import ( + "encoding/json" + "time" +) // HookResultStatus 表示单个 hook 的执行结果状态。 type HookResultStatus string @@ -36,6 +39,10 @@ type HookResultMetadata struct { OriginalStatus string BlockDowngraded bool GuardSignal bool + + // P6 command hook 协议字段 + Annotations []string // stdout JSON "annotations" 数组 + UpdateInput json.RawMessage // stdout JSON "update_input" 原始字节 } // RunOutput 是一次点位执行的聚合结果。 diff --git a/internal/runtime/hooks/types.go b/internal/runtime/hooks/types.go index e51d5d6f2..2601f21e7 100644 --- a/internal/runtime/hooks/types.go +++ b/internal/runtime/hooks/types.go @@ -2,6 +2,7 @@ package hooks import ( "context" + "sort" "strings" "time" ) @@ -44,22 +45,77 @@ type 
HookPointCapability struct { CanAnnotate bool CanUpdateInput bool UserAllowed bool + Matcher HookMatcherCapability +} + +// HookMatcherCapability 描述点位可用的 matcher 维度。 +type HookMatcherCapability struct { + ToolName bool + ToolNameRegex bool + ArgumentsContains bool } var hookPointCapabilities = map[HookPoint]HookPointCapability{ - HookPointBeforeToolCall: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointAfterToolResult: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointBeforeCompletionDecision: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointAcceptGate: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointBeforePermissionDecision: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointAfterToolFailure: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSessionStart: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSessionEnd: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointUserPromptSubmit: {CanBlock: true, CanAnnotate: true, CanUpdateInput: true, UserAllowed: true}, - HookPointPreCompact: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointPostCompact: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, - HookPointSubAgentStart: {CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false}, - HookPointSubAgentStop: {CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true}, + HookPointBeforeToolCall: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: true, + }, + }, + HookPointAfterToolResult: { + CanBlock: false, CanAnnotate: 
true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: false, + }, + }, + HookPointBeforeCompletionDecision: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointAcceptGate: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointBeforePermissionDecision: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: false, + }, + }, + HookPointAfterToolFailure: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{ + ToolName: true, ToolNameRegex: true, ArgumentsContains: true, + }, + }, + HookPointSessionStart: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointSessionEnd: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointUserPromptSubmit: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: true, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointPreCompact: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{}, + }, + HookPointPostCompact: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, + HookPointSubAgentStart: { + CanBlock: true, CanAnnotate: true, CanUpdateInput: false, UserAllowed: false, + Matcher: HookMatcherCapability{}, + }, + HookPointSubAgentStop: { + CanBlock: false, CanAnnotate: true, CanUpdateInput: false, UserAllowed: true, + Matcher: HookMatcherCapability{}, + }, } // HookScope 描述 hook 的权限/上下文裁剪等级。 @@ -141,6 +197,7 @@ 
type HookSpec struct { Timeout time.Duration FailurePolicy FailurePolicy Handler HookHandler + Matcher *HookMatcher } // normalizeAndValidate 将 HookSpec 归一化并校验当前阶段可用字段。 @@ -224,3 +281,30 @@ func HookPointCapabilities(point HookPoint) (HookPointCapability, bool) { capability, ok := hookPointCapabilities[point] return capability, ok } + +// ListHookPoints 返回所有已注册的 hook 点位(按字符串排序,保证确定性)。 +func ListHookPoints() []HookPoint { + points := make([]HookPoint, 0, len(hookPointCapabilities)) + for point := range hookPointCapabilities { + points = append(points, point) + } + sort.Slice(points, func(i, j int) bool { + return points[i] < points[j] + }) + return points +} + +// IsUserAllowed 返回指定点位是否允许 user scope hook 挂载。 +func IsUserAllowed(point HookPoint) bool { + capability, ok := hookPointCapabilities[point] + if !ok { + return false + } + return capability.UserAllowed +} + +// IsRepoAllowed 返回指定点位是否允许 repo scope hook 挂载。 +// 当前 repo 与 user 共享相同的 allowed 策略。 +func IsRepoAllowed(point HookPoint) bool { + return IsUserAllowed(point) +} diff --git a/internal/runtime/hooks/types_test.go b/internal/runtime/hooks/types_test.go index cd222fa86..3664133cb 100644 --- a/internal/runtime/hooks/types_test.go +++ b/internal/runtime/hooks/types_test.go @@ -217,6 +217,9 @@ func TestHookPointCapabilities(t *testing.T) { if !capability.CanBlock { t.Fatal("before_permission_decision should allow block") } + if !capability.Matcher.ToolName || !capability.Matcher.ToolNameRegex { + t.Fatal("before_permission_decision should support tool_name/tool_name_regex matcher") + } capability, ok = HookPointCapabilities(HookPointAfterToolFailure) if !ok { @@ -225,6 +228,9 @@ func TestHookPointCapabilities(t *testing.T) { if capability.CanBlock { t.Fatal("after_tool_failure should be observe-only") } + if !capability.Matcher.ArgumentsContains { + t.Fatal("after_tool_failure should support arguments_contains matcher") + } capability, ok = HookPointCapabilities(HookPointBeforeCompletionDecision) if !ok { @@ 
-241,8 +247,111 @@ func TestHookPointCapabilities(t *testing.T) { if !capability.CanBlock { t.Fatal("accept_gate should allow block") } + if capability.Matcher.ToolName || capability.Matcher.ToolNameRegex || capability.Matcher.ArgumentsContains { + t.Fatal("accept_gate should not expose matcher dimensions") + } if _, exists := HookPointCapabilities(HookPoint("unknown")); exists { t.Fatal("unknown hook point should not have capability") } } + +func TestListHookPointsReturnsSortedSlice(t *testing.T) { + t.Parallel() + + points := ListHookPoints() + if len(points) == 0 { + t.Fatal("expected at least one hook point") + } + + // 验证包含所有已知点位。 + expected := []HookPoint{ + HookPointBeforeToolCall, + HookPointAfterToolResult, + HookPointBeforeCompletionDecision, + HookPointAcceptGate, + HookPointBeforePermissionDecision, + HookPointAfterToolFailure, + HookPointSessionStart, + HookPointSessionEnd, + HookPointUserPromptSubmit, + HookPointPreCompact, + HookPointPostCompact, + HookPointSubAgentStart, + HookPointSubAgentStop, + } + if len(points) != len(expected) { + t.Fatalf("ListHookPoints() len = %d, want %d", len(points), len(expected)) + } + + // 验证排序。 + for i := 1; i < len(points); i++ { + if points[i] < points[i-1] { + t.Fatalf("ListHookPoints() not sorted: %q < %q at index %d", points[i], points[i-1], i) + } + } + + // 验证包含所有点位。 + pointSet := make(map[HookPoint]struct{}, len(points)) + for _, p := range points { + pointSet[p] = struct{}{} + } + for _, p := range expected { + if _, ok := pointSet[p]; !ok { + t.Fatalf("ListHookPoints() missing point %q", p) + } + } +} + +func TestIsUserAllowed(t *testing.T) { + t.Parallel() + + allowedPoints := []HookPoint{ + HookPointBeforeToolCall, + HookPointAfterToolResult, + HookPointBeforeCompletionDecision, + HookPointAcceptGate, + HookPointAfterToolFailure, + HookPointSessionStart, + HookPointSessionEnd, + HookPointUserPromptSubmit, + HookPointPostCompact, + HookPointSubAgentStop, + } + for _, p := range allowedPoints { + if 
!IsUserAllowed(p) { + t.Fatalf("IsUserAllowed(%q) = false, want true", p) + } + } + + disallowedPoints := []HookPoint{ + HookPointBeforePermissionDecision, + HookPointPreCompact, + HookPointSubAgentStart, + } + for _, p := range disallowedPoints { + if IsUserAllowed(p) { + t.Fatalf("IsUserAllowed(%q) = true, want false", p) + } + } + + // 未知点位应返回 false。 + if IsUserAllowed(HookPoint("unknown_point")) { + t.Fatal("IsUserAllowed(unknown) should return false") + } +} + +func TestIsRepoAllowed(t *testing.T) { + t.Parallel() + + // 当前 repo 与 user 共享策略。 + if !IsRepoAllowed(HookPointBeforeToolCall) { + t.Fatal("IsRepoAllowed(before_tool_call) should be true") + } + if IsRepoAllowed(HookPointPreCompact) { + t.Fatal("IsRepoAllowed(pre_compact) should be false") + } + if IsRepoAllowed(HookPoint("unknown")) { + t.Fatal("IsRepoAllowed(unknown) should be false") + } +} diff --git a/internal/runtime/hooks_integration.go b/internal/runtime/hooks_integration.go index b3badecf0..8a2cfdbe0 100644 --- a/internal/runtime/hooks_integration.go +++ b/internal/runtime/hooks_integration.go @@ -2,8 +2,10 @@ package runtime import ( "context" + "encoding/json" "strings" + providertypes "neo-code/internal/provider/types" runtimehooks "neo-code/internal/runtime/hooks" ) @@ -229,10 +231,15 @@ func (s *Service) recordUserHookAnnotations(state *runState, output runtimehooks continue } message := strings.TrimSpace(result.Message) - if message == "" { - continue + if message != "" { + notes = append(notes, message) + } + for _, annotation := range result.Metadata.Annotations { + trimmed := strings.TrimSpace(annotation) + if trimmed != "" { + notes = append(notes, trimmed) + } } - notes = append(notes, message) } if len(notes) == 0 { return @@ -241,3 +248,41 @@ func (s *Service) recordUserHookAnnotations(state *runState, output runtimehooks state.hookAnnotations = append(state.hookAnnotations, notes...) 
state.mu.Unlock() } + +// applyCommandHookUpdateInput 检查 hook 输出中的 update_input 并应用到用户输入 parts。 +// 当前仅支持 user_prompt_submit 点位;update_input 格式: {"text": "..."} 替换文本内容。 +func applyCommandHookUpdateInput(output runtimehooks.RunOutput, parts []providertypes.ContentPart) []providertypes.ContentPart { + if len(output.Results) == 0 { + return parts + } + for _, result := range output.Results { + if len(result.Metadata.UpdateInput) == 0 { + continue + } + cap, ok := runtimehooks.HookPointCapabilities(result.Point) + if !ok || !cap.CanUpdateInput { + continue + } + var update struct { + Text string `json:"text"` + } + if err := json.Unmarshal(result.Metadata.UpdateInput, &update); err != nil { + continue + } + if update.Text == "" { + continue + } + replaced := false + newParts := make([]providertypes.ContentPart, 0, len(parts)) + for _, part := range parts { + if !replaced && part.Kind == providertypes.ContentPartText { + newParts = append(newParts, providertypes.NewTextPart(update.Text)) + replaced = true + } else { + newParts = append(newParts, part) + } + } + return newParts + } + return parts +} diff --git a/internal/runtime/hooks_integration_test.go b/internal/runtime/hooks_integration_test.go index aa820bd41..73bc279aa 100644 --- a/internal/runtime/hooks_integration_test.go +++ b/internal/runtime/hooks_integration_test.go @@ -1253,3 +1253,142 @@ func TestEmitSubAgentStopHookNilServiceNoop(t *testing.T) { Error: "noop", }) } + +func TestApplyCommandHookUpdateInput(t *testing.T) { + t.Parallel() + + t.Run("empty results returns parts unchanged", func(t *testing.T) { + t.Parallel() + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(runtimehooks.RunOutput{}, parts) + if len(got) != 1 || got[0].Text != "original" { + t.Fatalf("got %v, want original parts unchanged", got) + } + }) + + t.Run("replaces first text part when CanUpdateInput", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ 
+ Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":"rewritten"}`), + }, + }}, + } + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 1 || got[0].Text != "rewritten" { + t.Fatalf("got %v, want text replaced to 'rewritten'", got) + } + }) + + t.Run("ignores when CanUpdateInput is false", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointBeforeToolCall, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":"should not apply"}`), + }, + }}, + } + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 1 || got[0].Text != "original" { + t.Fatalf("got %v, want original parts unchanged for non-CanUpdateInput point", got) + } + }) + + t.Run("ignores invalid JSON in UpdateInput", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`not json`), + }, + }}, + } + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 1 || got[0].Text != "original" { + t.Fatalf("got %v, want original parts unchanged for invalid JSON", got) + } + }) + + t.Run("ignores empty text in UpdateInput", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: 
runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":""}`), + }, + }}, + } + parts := []providertypes.ContentPart{providertypes.NewTextPart("original")} + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 1 || got[0].Text != "original" { + t.Fatalf("got %v, want original parts unchanged for empty text", got) + } + }) + + t.Run("only replaces first text part", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":"new"}`), + }, + }}, + } + parts := []providertypes.ContentPart{ + providertypes.NewTextPart("first"), + providertypes.NewTextPart("second"), + } + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 2 { + t.Fatalf("got len %d, want 2", len(got)) + } + if got[0].Text != "new" { + t.Fatalf("first part text = %q, want 'new'", got[0].Text) + } + if got[1].Text != "second" { + t.Fatalf("second part text = %q, want 'second' (unchanged)", got[1].Text) + } + }) + + t.Run("preserves non-text parts", func(t *testing.T) { + t.Parallel() + output := runtimehooks.RunOutput{ + Results: []runtimehooks.HookResult{{ + Point: runtimehooks.HookPointUserPromptSubmit, + Status: runtimehooks.HookResultPass, + Metadata: runtimehooks.HookResultMetadata{ + UpdateInput: []byte(`{"text":"replaced"}`), + }, + }}, + } + parts := []providertypes.ContentPart{ + providertypes.NewRemoteImagePart("https://example.com/img.png"), + providertypes.NewTextPart("original"), + } + got := applyCommandHookUpdateInput(output, parts) + if len(got) != 2 { + t.Fatalf("got len %d, want 2", len(got)) + } + if got[0].Kind != providertypes.ContentPartImage { + t.Fatalf("first part kind = %q, want image (unchanged)", got[0].Kind) + } + if got[1].Text != "replaced" { + t.Fatalf("second part text = %q, want 'replaced'", got[1].Text) + } + }) +} diff 
--git a/internal/runtime/input_prepare.go b/internal/runtime/input_prepare.go index 0752a370e..8b9e70827 100644 --- a/internal/runtime/input_prepare.go +++ b/internal/runtime/input_prepare.go @@ -148,6 +148,7 @@ func (p sessionInputPreparer) Prepare( for _, image := range input.Images { sessionImages = append(sessionImages, agentsession.PrepareImageInput{ Path: strings.TrimSpace(image.Path), + AssetID: strings.TrimSpace(image.AssetID), MimeType: strings.TrimSpace(image.MimeType), }) } diff --git a/internal/runtime/max_turn_error.go b/internal/runtime/max_turn_error.go index d52616d80..831c023ff 100644 --- a/internal/runtime/max_turn_error.go +++ b/internal/runtime/max_turn_error.go @@ -1,6 +1,9 @@ package runtime -import "fmt" +import ( + "errors" + "fmt" +) // maxTurnLimitError 表示 Run 达到 runtime.max_turns 上限后触发的受控停止错误。 type maxTurnLimitError struct { @@ -21,3 +24,9 @@ func (e maxTurnLimitError) Limit() int { func newMaxTurnLimitError(limit int) error { return maxTurnLimitError{limit: limit} } + +// IsMaxTurnLimitError 判断错误链是否来自 runtime.max_turns 受控停止。 +func IsMaxTurnLimitError(err error) bool { + var target maxTurnLimitError + return errors.As(err, &target) +} diff --git a/internal/runtime/plan_approval.go b/internal/runtime/plan_approval.go index f1be721dd..e083ade68 100644 --- a/internal/runtime/plan_approval.go +++ b/internal/runtime/plan_approval.go @@ -7,6 +7,25 @@ import ( "time" ) +var ( + // ErrPlanApprovalCurrentPlanMissing 表示当前会话没有可批准的计划。 + ErrPlanApprovalCurrentPlanMissing = errors.New("runtime plan approval current plan missing") + // ErrPlanApprovalPlanIDMismatch 表示客户端批准的计划 ID 已不是当前计划。 + ErrPlanApprovalPlanIDMismatch = errors.New("runtime plan approval plan id mismatch") + // ErrPlanApprovalRevisionMismatch 表示客户端批准的 revision 已过期或非法。 + ErrPlanApprovalRevisionMismatch = errors.New("runtime plan approval revision mismatch") + // ErrPlanApprovalStatusInvalid 表示当前计划状态不允许批准。 + ErrPlanApprovalStatusInvalid = errors.New("runtime plan approval status invalid") 
+) + +// IsPlanApprovalInvalidError 判断错误是否属于可预期的计划审批业务拒绝。 +func IsPlanApprovalInvalidError(err error) bool { + return errors.Is(err, ErrPlanApprovalCurrentPlanMissing) || + errors.Is(err, ErrPlanApprovalPlanIDMismatch) || + errors.Is(err, ErrPlanApprovalRevisionMismatch) || + errors.Is(err, ErrPlanApprovalStatusInvalid) +} + // ApproveCurrentPlan 显式批准当前完整计划 revision,并安排下一轮做一次完整计划对齐。 func (s *Service) ApproveCurrentPlan(ctx context.Context, input ApproveCurrentPlanInput) error { if err := ctx.Err(); err != nil { diff --git a/internal/runtime/planning.go b/internal/runtime/planning.go index 976905416..716c39fad 100644 --- a/internal/runtime/planning.go +++ b/internal/runtime/planning.go @@ -174,8 +174,9 @@ func decodePlanTurnOutput(jsonText string) (planTurnOutput, error) { // stripPlanningJSONObjectText 从原始回复中移除结构化 JSON,并尽量保留自然段落间距。 func stripPlanningJSONObjectText(text string, candidate extractedPlanningJSONObject) string { - before := strings.TrimSpace(text[:candidate.Start]) - after := strings.TrimSpace(text[candidate.End:]) + start, end := planningJSONObjectRemovalRange(text, candidate) + before := strings.TrimSpace(text[:start]) + after := strings.TrimSpace(text[end:]) switch { case before == "": return after @@ -186,6 +187,28 @@ func stripPlanningJSONObjectText(text string, candidate extractedPlanningJSONObj } } +// planningJSONObjectRemovalRange 扩展结构化 JSON 的剥离范围,避免 HTML 注释外壳泄漏到可见计划正文。 +func planningJSONObjectRemovalRange(text string, candidate extractedPlanningJSONObject) (int, int) { + start := candidate.Start + end := candidate.End + if start < 0 || end < start || end > len(text) { + return candidate.Start, candidate.End + } + + prefix := text[:start] + open := strings.LastIndex(prefix, "<!--") + if open < 0 || strings.TrimSpace(prefix[open+len("<!--"):]) != "" { + return start, end + } + + suffix := text[end:] + closeOffset := strings.Index(suffix, "-->") + if closeOffset < 0 || strings.TrimSpace(suffix[:closeOffset]) != "" { + return start, end + } + return open, end + closeOffset + len("-->") +} + // extractPlanningJSONObjectIfPresent 在文本中提取首个满足指定顶层键契约的 JSON 对象。 func extractPlanningJSONObjectIfPresent(text string,
requiredKey string) (extractedPlanningJSONObject, bool) { start := strings.IndexByte(text, '{') @@ -258,11 +281,43 @@ func buildPlanArtifact(current *agentsession.PlanArtifact, output planTurnOutput return plan, nil } -// resolvePlanDisplayText 优先保留模型对计划的额外说明文本,缺失时回退为规范计划正文。 -func resolvePlanDisplayText(output planTurnOutput, spec agentsession.PlanSpec) string { - display := strings.TrimSpace(output.DisplayText) - if display != "" { - return display +// renderPlanMarkdown 将结构化计划渲染为前端可直接展示的规范 Markdown。 +func renderPlanMarkdown(spec agentsession.PlanSpec) string { + spec, err := agentsession.NormalizePlanSpec(spec) + if err != nil { + return "" + } + sections := make([]string, 0, 4) + sections = append(sections, "### 目标\n\n"+spec.Goal) + if len(spec.Steps) > 0 { + sections = append(sections, "### 实施步骤\n\n"+renderMarkdownBulletList(spec.Steps)) + } + if len(spec.Constraints) > 0 { + sections = append(sections, "### 约束\n\n"+renderMarkdownBulletList(spec.Constraints)) + } + if len(spec.OpenQuestions) > 0 { + sections = append(sections, "### 未决问题\n\n"+renderMarkdownBulletList(spec.OpenQuestions)) + } + return strings.TrimSpace(strings.Join(sections, "\n\n")) +} + +// renderMarkdownBulletList 将计划字段中的字符串列表渲染为 Markdown 无序列表。 +func renderMarkdownBulletList(items []string) string { + lines := make([]string, 0, len(items)) + for _, item := range items { + trimmed := strings.TrimSpace(item) + if trimmed == "" { + continue + } + lines = append(lines, "- "+trimmed) + } + return strings.Join(lines, "\n") +} + +// resolvePlanDisplayText 在解析出机器可读计划后固定返回规范化计划正文,不保留模型额外说明。 +func resolvePlanDisplayText(_ planTurnOutput, spec agentsession.PlanSpec) string { + if markdown := renderPlanMarkdown(spec); markdown != "" { + return markdown } return strings.TrimSpace(agentsession.RenderPlanContent(spec)) } @@ -344,16 +399,16 @@ func rememberFullPlanRevision(session *agentsession.Session) bool { // approveCurrentPlan 显式批准当前 draft revision,并安排下一轮做一次完整计划对齐。 func approveCurrentPlan(session 
*agentsession.Session, planID string, revision int) error { if session == nil || session.CurrentPlan == nil { - return fmt.Errorf("runtime: current plan does not exist") + return fmt.Errorf("%w: current plan does not exist", ErrPlanApprovalCurrentPlanMissing) } if strings.TrimSpace(planID) == "" || strings.TrimSpace(session.CurrentPlan.ID) != strings.TrimSpace(planID) { - return fmt.Errorf("runtime: current plan id does not match") + return fmt.Errorf("%w: current plan id does not match", ErrPlanApprovalPlanIDMismatch) } if revision <= 0 || session.CurrentPlan.Revision != revision { - return fmt.Errorf("runtime: current plan revision does not match") + return fmt.Errorf("%w: current plan revision does not match", ErrPlanApprovalRevisionMismatch) } if session.CurrentPlan.Status != agentsession.PlanStatusDraft { - return fmt.Errorf("runtime: current plan status %q cannot be approved", session.CurrentPlan.Status) + return fmt.Errorf("%w: current plan status %q cannot be approved", ErrPlanApprovalStatusInvalid, session.CurrentPlan.Status) } session.CurrentPlan = session.CurrentPlan.Clone() session.CurrentPlan.Status = agentsession.PlanStatusApproved diff --git a/internal/runtime/planning_test.go b/internal/runtime/planning_test.go index 1c74fb8c4..e09d77563 100644 --- a/internal/runtime/planning_test.go +++ b/internal/runtime/planning_test.go @@ -1,6 +1,7 @@ package runtime import ( + "errors" "reflect" "strings" "testing" @@ -133,6 +134,29 @@ func TestMaybeParsePlanTurnOutputIgnoresBraceTextAndKeepsExplanation(t *testing. 
} } +func TestMaybeParsePlanTurnOutputStripsHTMLCommentJSON(t *testing.T) { + t.Parallel() + + markdown := "### Goal\n\nShip plan display\n\n### Steps\n\n- Align prompts" + text := markdown + "\n\n" + output, ok, err := maybeParsePlanTurnOutput(providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart(text)}, + }) + if err != nil { + t.Fatalf("maybeParsePlanTurnOutput() error = %v", err) + } + if !ok { + t.Fatal("expected HTML comment plan JSON to be detected") + } + if output.PlanSpec.Goal != "Ship plan display" { + t.Fatalf("PlanSpec.Goal = %q", output.PlanSpec.Goal) + } + if output.DisplayText != markdown { + t.Fatalf("DisplayText = %q, want %q", output.DisplayText, markdown) + } +} + func TestMaybeParsePlanTurnOutputFallsBackWhenSummaryIsInvalid(t *testing.T) { t.Parallel() @@ -512,7 +536,8 @@ func TestApproveCurrentPlanValidationErrors(t *testing.T) { t.Parallel() session := agentsession.New("approve validation") - if err := approveCurrentPlan(&session, "plan-1", 1); err == nil { + if err := approveCurrentPlan(&session, "plan-1", 1); !errors.Is(err, ErrPlanApprovalCurrentPlanMissing) || + !IsPlanApprovalInvalidError(err) { t.Fatal("expected error when current plan does not exist") } @@ -526,15 +551,18 @@ func TestApproveCurrentPlanValidationErrors(t *testing.T) { }, } - if err := approveCurrentPlan(&session, "plan-2", 2); err == nil { - t.Fatal("expected id mismatch error") + if err := approveCurrentPlan(&session, "plan-2", 2); !errors.Is(err, ErrPlanApprovalPlanIDMismatch) || + !IsPlanApprovalInvalidError(err) { + t.Fatalf("expected id mismatch error, got %v", err) } - if err := approveCurrentPlan(&session, "plan-1", 1); err == nil { - t.Fatal("expected revision mismatch error") + if err := approveCurrentPlan(&session, "plan-1", 1); !errors.Is(err, ErrPlanApprovalRevisionMismatch) || + !IsPlanApprovalInvalidError(err) { + t.Fatalf("expected revision mismatch error, got %v", err) } 
session.CurrentPlan.Status = agentsession.PlanStatusApproved - if err := approveCurrentPlan(&session, "plan-1", 2); err == nil { - t.Fatal("expected status mismatch error") + if err := approveCurrentPlan(&session, "plan-1", 2); !errors.Is(err, ErrPlanApprovalStatusInvalid) || + !IsPlanApprovalInvalidError(err) { + t.Fatalf("expected status mismatch error, got %v", err) } } diff --git a/internal/runtime/repo_hooks.go b/internal/runtime/repo_hooks.go index 9fb5140cd..6983d46d0 100644 --- a/internal/runtime/repo_hooks.go +++ b/internal/runtime/repo_hooks.go @@ -320,25 +320,11 @@ func validateRepoHookItem(item config.RuntimeHookItemConfig) error { if strings.TrimSpace(item.ID) == "" { return fmt.Errorf("id is required") } - point := strings.ToLower(strings.TrimSpace(item.Point)) - switch point { - case string(runtimehooks.HookPointBeforeToolCall), - string(runtimehooks.HookPointAfterToolResult), - string(runtimehooks.HookPointBeforeCompletionDecision), - string(runtimehooks.HookPointAcceptGate), - string(runtimehooks.HookPointBeforePermissionDecision), - string(runtimehooks.HookPointAfterToolFailure), - string(runtimehooks.HookPointSessionStart), - string(runtimehooks.HookPointSessionEnd), - string(runtimehooks.HookPointUserPromptSubmit), - string(runtimehooks.HookPointPreCompact), - string(runtimehooks.HookPointPostCompact), - string(runtimehooks.HookPointSubAgentStart), - string(runtimehooks.HookPointSubAgentStop): - default: + point := runtimehooks.HookPoint(strings.ToLower(strings.TrimSpace(item.Point))) + if _, ok := runtimehooks.HookPointCapabilities(point); !ok { return fmt.Errorf("point %q is not supported", item.Point) } - if capability, ok := runtimehooks.HookPointCapabilities(runtimehooks.HookPoint(point)); ok && !capability.UserAllowed { + if !runtimehooks.IsRepoAllowed(point) { return fmt.Errorf("point %q does not allow repo hooks", item.Point) } if strings.ToLower(strings.TrimSpace(item.Scope)) != repoHookScopeValue { @@ -373,31 +359,25 @@ func 
validateRepoHookItem(item config.RuntimeHookItemConfig) error { default: return fmt.Errorf("handler %q is not supported", item.Handler) } - if handler == "warn_on_tool_call" && !runtimeHasWarnOnToolCallTargets(item.Params) { - return fmt.Errorf("handler %q requires params.tool_name or params.tool_names", item.Handler) + if handler == "warn_on_tool_call" && !runtimehooks.HasHookMatcherConfig(item.Match) { + return fmt.Errorf("handler %q requires match", item.Handler) + } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(point, item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } } case repoHookKindCommand: - if strings.TrimSpace(readHookParamString(item.Params, "command")) == "" { - return fmt.Errorf("kind command requires params.command") + if err := runtimehooks.ValidateCommandParams(item.Params); err != nil { + return err } - } - return nil -} - -// runtimeHasWarnOnToolCallTargets 判断 warn_on_tool_call 是否配置了至少一个目标工具。 -func runtimeHasWarnOnToolCallTargets(params map[string]any) bool { - if len(params) == 0 { - return false - } - if name := strings.TrimSpace(readHookParamString(params, "tool_name")); name != "" { - return true - } - for _, value := range readHookParamStringSlice(params, "tool_names") { - if strings.TrimSpace(value) != "" { - return true + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(point, item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } } } - return false + return nil } // evaluateWorkspaceTrust 根据 trust store 判断 workspace 是否可信并附带容错诊断。 diff --git a/internal/runtime/repo_hooks_test.go b/internal/runtime/repo_hooks_test.go index bb50f7f1c..c7df2ade8 100644 --- a/internal/runtime/repo_hooks_test.go +++ b/internal/runtime/repo_hooks_test.go @@ -596,25 +596,46 @@ func TestValidateRepoHookItemRejectsExternalKindsWithP6LiteMessage(t *testing.T) } } -func TestRuntimeHasWarnOnToolCallTargetsBranches(t *testing.T) { - cases := 
[]struct { - name string - params map[string]any - want bool - }{ - {name: "nil", params: nil, want: false}, - {name: "tool_name", params: map[string]any{"tool_name": "bash"}, want: true}, - {name: "tool_name blank", params: map[string]any{"tool_name": " "}, want: false}, - {name: "tool_names", params: map[string]any{"tool_names": []any{"bash"}}, want: true}, - {name: "tool_names blank", params: map[string]any{"tool_names": []any{" "}}, want: false}, +func TestValidateRepoHookItemAllowsWarnOnToolCallWithMatchOnly(t *testing.T) { + t.Parallel() + + item := config.RuntimeHookItemConfig{ + ID: "repo-warn-match", + Point: "before_tool_call", + Scope: "repo", + Kind: "builtin", + Mode: "sync", + Handler: "warn_on_tool_call", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Match: map[string]any{ + "tool_name": "bash", + }, } - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - if got := runtimeHasWarnOnToolCallTargets(tc.params); got != tc.want { - t.Fatalf("runtimeHasWarnOnToolCallTargets() = %v, want %v", got, tc.want) - } - }) + if err := validateRepoHookItem(item); err != nil { + t.Fatalf("validateRepoHookItem() error = %v", err) + } +} + +func TestValidateRepoHookItemRejectsUnsupportedMatcherDimension(t *testing.T) { + t.Parallel() + + item := config.RuntimeHookItemConfig{ + ID: "repo-session-match", + Point: "session_start", + Scope: "repo", + Kind: "builtin", + Mode: "sync", + Handler: "add_context_note", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Params: map[string]any{"note": "repo"}, + Match: map[string]any{ + "tool_name": "bash", + }, + } + if err := validateRepoHookItem(item); err == nil { + t.Fatal("expected unsupported matcher dimension to fail") } } @@ -808,3 +829,271 @@ func TestRepoHookEventEmittersAndHelpers(t *testing.T) { t.Fatalf("coalesceHookMessage(blank) = %q, want empty", got) } } + +// TestRepoHookPointSingleSourceConsistency 验证 repo 侧与 runtime hooks 包的点位定义一致。 +// 新增 hook point 时只需修改 runtime hooks 包,repo 验证自动接受。 
+func TestRepoHookPointSingleSourceConsistency(t *testing.T) { + t.Parallel() + + allPoints := runtimehooks.ListHookPoints() + if len(allPoints) == 0 { + t.Fatal("expected at least one hook point from runtime hooks package") + } + + base := config.RuntimeHookItemConfig{ + ID: "repo-consistency", + Scope: "repo", + Kind: "builtin", + Mode: "sync", + Handler: "add_context_note", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Params: map[string]any{"note": "consistency check"}, + } + + for _, point := range allPoints { + point := point + t.Run(string(point), func(t *testing.T) { + t.Parallel() + if !runtimehooks.IsRepoAllowed(point) { + // 跳过不允许 repo 的点位。 + return + } + item := base.Clone() + item.Point = string(point) + if err := validateRepoHookItem(item); err != nil { + t.Fatalf("repo validation rejected point %q: %v", point, err) + } + }) + } + + // 验证 accept_gate 在 runtime hooks 包中存在且允许 repo。 + acceptGateCap, ok := runtimehooks.HookPointCapabilities(runtimehooks.HookPointAcceptGate) + if !ok { + t.Fatal("accept_gate not found in runtime hooks capabilities") + } + if !acceptGateCap.UserAllowed { + t.Fatal("accept_gate should allow repo hooks (via UserAllowed)") + } +} + +func TestEvaluateWorkspaceTrustPermissionAndNormalizeErrorBranches(t *testing.T) { + homeDir := t.TempDir() + t.Setenv("HOME", homeDir) + workspace := filepath.Join(homeDir, "workspace") + if err := os.MkdirAll(workspace, 0o755); err != nil { + t.Fatalf("mkdir workspace: %v", err) + } + + storePath := resolveTrustedWorkspacesPath() + if err := os.MkdirAll(filepath.Dir(storePath), 0o755); err != nil { + t.Fatalf("mkdir trust store dir: %v", err) + } + + // 分支:trust store 是目录,触发 read 或 permissions 错误。 + if err := os.MkdirAll(storePath, 0o755); err != nil { + t.Fatalf("mkdir store as dir: %v", err) + } + decision := evaluateWorkspaceTrust(workspace) + if decision.Trusted { + t.Fatal("expected untrusted when trust store is a directory") + } + if strings.TrimSpace(decision.InvalidReason) == "" { + 
t.Fatal("expected non-empty invalid reason when trust store is a directory") + } + + // 清理并写入包含相对路径 entry 的 trust store。 + if err := os.RemoveAll(storePath); err != nil { + t.Fatalf("remove store dir: %v", err) + } + store := trustedWorkspaceStore{ + Version: repoHooksTrustStoreVersion, + Workspaces: []string{"relative/path/not/absolute"}, + } + rawStore, err := json.Marshal(store) + if err != nil { + t.Fatalf("marshal trust store: %v", err) + } + if err := os.WriteFile(storePath, rawStore, 0o644); err != nil { + t.Fatalf("write trust store: %v", err) + } + + // 分支:workspaces 中的 entry 无法 normalize(相对路径)。 + decision = evaluateWorkspaceTrust(workspace) + if decision.Trusted { + t.Fatal("expected untrusted when workspaces entry is relative") + } + if !strings.Contains(decision.InvalidReason, "invalid") { + t.Fatalf("expected invalid path error, got: %s", decision.InvalidReason) + } +} + +func TestLoadRepoHookItemsErrorBranches(t *testing.T) { + workspace := t.TempDir() + + // 分支:YAML 结构不匹配(未知字段)。 + hooksPath := filepath.Join(workspace, ".neocode", "hooks.yaml") + if err := os.MkdirAll(filepath.Dir(hooksPath), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(hooksPath, []byte(` +hooks: + items: + - id: bad + unknown_field: value +`), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + _, err := loadRepoHookItems(hooksPath, config.StaticDefaults().Runtime.Hooks) + if err == nil { + t.Fatal("expected unknown field error") + } + + // 分支:item 校验失败(空 id)。 + if err := os.WriteFile(hooksPath, []byte(` +hooks: + items: + - point: before_tool_call +`), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + _, err = loadRepoHookItems(hooksPath, config.StaticDefaults().Runtime.Hooks) + if err == nil { + t.Fatal("expected validation error for empty id") + } +} + +func TestBuildRepoHookExecutorForWorkspaceErrorPaths(t *testing.T) { + homeDir := t.TempDir() + t.Setenv("HOME", homeDir) + + // 分支:loadRepoHookItems 解析失败(trust 已通过但 hooks 文件损坏)。 + 
workspace := filepath.Join(homeDir, "bad-hooks") + if err := os.MkdirAll(filepath.Join(workspace, ".neocode"), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + if err := os.WriteFile(filepath.Join(workspace, ".neocode", "hooks.yaml"), []byte(`not: valid: yaml: [`), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + storePath := resolveTrustedWorkspacesPath() + if err := os.MkdirAll(filepath.Dir(storePath), 0o755); err != nil { + t.Fatalf("mkdir trust store dir: %v", err) + } + rawStore, err := json.Marshal(trustedWorkspaceStore{ + Version: repoHooksTrustStoreVersion, + Workspaces: []string{workspace}, + }) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(storePath, rawStore, 0o644); err != nil { + t.Fatalf("write trust store: %v", err) + } + + service := &Service{events: make(chan RuntimeEvent, 16)} + _, err = buildRepoHookExecutorForWorkspace(service, workspace, config.StaticDefaults().Runtime.Hooks) + if err == nil { + t.Fatal("expected error from corrupted hooks file") + } + + // 分支:hooksPath 解析失败(workspace 为空)。 + service2 := &Service{events: make(chan RuntimeEvent, 16)} + exec, err := buildRepoHookExecutorForWorkspace(service2, " ", config.StaticDefaults().Runtime.Hooks) + if err != nil { + t.Fatalf("buildRepoHookExecutorForWorkspace(blank) error = %v", err) + } + if exec != nil { + t.Fatal("expected nil executor for blank workspace") + } +} + +func TestValidateRepoHookItemCommandKindBranches(t *testing.T) { + t.Parallel() + + // 分支:kind=command 且 params.command 存在时通过。 + item := config.RuntimeHookItemConfig{ + ID: "cmd-ok", + Point: "before_tool_call", + Scope: "repo", + Kind: "command", + Mode: "sync", + TimeoutSec: 2, + FailurePolicy: "warn_only", + Params: map[string]any{"command": []any{"echo", "ok"}}, + } + if err := validateRepoHookItem(item); err != nil { + t.Fatalf("validateRepoHookItem(command with params) error = %v", err) + } + + // 分支:kind=command 但 params.command 为空。 + item2 := item.Clone() + item2.Params = 
map[string]any{} + if err := validateRepoHookItem(item2); err == nil { + t.Fatal("expected error for command without params.command") + } +} + +func TestBuildRepoHookExecutorForWorkspaceEmptyHooks(t *testing.T) { + homeDir := t.TempDir() + t.Setenv("HOME", homeDir) + workspace := filepath.Join(homeDir, "empty-hooks") + if err := os.MkdirAll(filepath.Join(workspace, ".neocode"), 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + // hooks 文件存在但所有 item 都 disabled。 + if err := os.WriteFile(filepath.Join(workspace, ".neocode", "hooks.yaml"), []byte(` +hooks: + items: + - id: disabled-hook + enabled: false + point: before_tool_call + scope: repo + kind: builtin + mode: sync + handler: add_context_note + params: + note: skip +`), 0o644); err != nil { + t.Fatalf("write: %v", err) + } + storePath := resolveTrustedWorkspacesPath() + if err := os.MkdirAll(filepath.Dir(storePath), 0o755); err != nil { + t.Fatalf("mkdir trust store dir: %v", err) + } + rawStore, err := json.Marshal(trustedWorkspaceStore{ + Version: repoHooksTrustStoreVersion, + Workspaces: []string{workspace}, + }) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := os.WriteFile(storePath, rawStore, 0o644); err != nil { + t.Fatalf("write trust store: %v", err) + } + + service := &Service{events: make(chan RuntimeEvent, 64)} + exec, err := buildRepoHookExecutorForWorkspace(service, workspace, config.StaticDefaults().Runtime.Hooks) + if err != nil { + t.Fatalf("buildRepoHookExecutorForWorkspace() error = %v", err) + } + if exec != nil { + t.Fatal("expected nil executor when all hooks are disabled") + } + + events := collectRuntimeEvents(service.Events()) + if !containsRuntimeEventType(events, EventRepoHooksLoaded) { + t.Fatalf("expected %s event", EventRepoHooksLoaded) + } +} + +func TestResolveTrustedWorkspacesPathHomeDirFallback(t *testing.T) { + // 分支:HOME 为相对路径,触发 UserHomeDir fallback。 + originalHome := os.Getenv("HOME") + t.Setenv("HOME", "relative-home-dir") + t.Cleanup(func() { 
os.Setenv("HOME", originalHome) }) + + path := resolveTrustedWorkspacesPath() + if !strings.Contains(path, ".neocode") { + t.Fatalf("expected .neocode in path, got: %s", path) + } +} diff --git a/internal/runtime/repository_context.go b/internal/runtime/repository_context.go index 143705c06..3377e15cc 100644 --- a/internal/runtime/repository_context.go +++ b/internal/runtime/repository_context.go @@ -6,8 +6,8 @@ import ( "strings" agentcontext "neo-code/internal/context" - "neo-code/internal/repository" providertypes "neo-code/internal/provider/types" + "neo-code/internal/repository" ) // buildRepositoryContext 返回最小 Git 摘要(迁移期保留),不再自动注入 changed-files 或 retrieval。 diff --git a/internal/runtime/run.go b/internal/runtime/run.go index e3b969ba6..82cd9ed30 100644 --- a/internal/runtime/run.go +++ b/internal/runtime/run.go @@ -35,6 +35,8 @@ const ( maxAcceptanceContinues = 3 ) +const historicalImageOmittedForModel = "[历史图片已省略:当前模型不支持图片输入]" + // computeToolSignature 计算单轮执行的工具签名,用于循环检测。 func computeToolSignature(calls []providertypes.ToolCall) string { if len(calls) == 0 { @@ -167,6 +169,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { state.taskID = strings.TrimSpace(input.TaskID) state.agentID = strings.TrimSpace(input.AgentID) state.userGoal = strings.TrimSpace(partsrender.RenderDisplayParts(input.Parts)) + state.currentInputParts = providertypes.CloneParts(input.Parts) if input.CapabilityToken != nil { token := input.CapabilityToken.Normalize() state.capabilityToken = &token @@ -203,6 +206,7 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { }) return s.handleRunError(errors.New(findHookBlockMessage(submitHookOutput))) } + input.Parts = applyCommandHookUpdateInput(submitHookOutput, input.Parts) if err := s.appendUserMessageAndSave(ctx, &state, input.Parts); err != nil { return s.handleRunError(err) } @@ -212,6 +216,13 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := 
s.maybeAppendPlanBootstrapReminder(ctx, &state); err != nil { return s.handleRunError(err) } + resolvedProviderForImages, modelForImages, err := resolveCompactProviderSelection(state.session, initialCfg) + if err != nil { + return s.handleRunError(err) + } + if err := rejectUnsupportedCurrentImageInput(modelForImages, resolvedProviderForImages.Models, input.Parts); err != nil { + return s.handleRunError(err) + } s.emitRuntimeSnapshotUpdated(ctx, &state, "session_start") s.updateResumeCheckpoint(ctx, &state, "plan", "") @@ -380,6 +391,10 @@ func (s *Service) Run(ctx context.Context, input UserInput) (err error) { if err := s.appendAssistantMessageOnlyAndSave(ctx, &state, planMessage); err != nil { return s.handleRunError(err) } + s.emitRunScoped(ctx, EventPlanUpdated, &state, PlanUpdatedPayload{ + CurrentPlan: nextPlan.Clone(), + DisplayText: resolvePlanDisplayText(planOutput, nextPlan.Spec), + }) s.emitRunScoped(ctx, EventAgentDone, &state, planMessage) return nil } @@ -566,9 +581,7 @@ func (s *Service) prepareTurnBudgetSnapshot(ctx context.Context, state *runState SessionOutputTokens: state.session.TokenOutputTotal, }, Compact: agentcontext.CompactOptions{ - DisableMicroCompact: cfg.Context.Compact.MicroCompactDisabled, - MicroCompactRetainedToolSpans: cfg.Context.Compact.MicroCompactRetainedToolSpans, - ReadTimeMaxMessageSpans: cfg.Context.Compact.ReadTimeMaxMessageSpans, + ReadTimeMaxMessageSpans: cfg.Context.Compact.ReadTimeMaxMessageSpans, }, }) if err != nil { @@ -626,7 +639,15 @@ func (s *Service) prepareTurnBudgetSnapshot(ctx context.Context, state *runState budgetCfg.SelectedProvider = resolvedProvider.Name budgetCfg.CurrentModel = model promptBudget, budgetSource, contextWindow := s.resolvePromptBudget(ctx, budgetCfg) - requestMessages := append([]providertypes.Message(nil), builtContext.Messages...) 
+ requestMessages, err := projectImagesForModelRequest( + model, + resolvedProvider.Models, + builtContext.Messages, + state.currentInputParts, + ) + if err != nil { + return TurnBudgetSnapshot{}, false, err + } thinkingCfg, thinkingErr := resolveThinkingConfig( modelCapabilityHintsForRequest(model, resolvedProvider.Models), state.thinkingOverride, @@ -835,7 +856,7 @@ func (s *Service) emitToolDiffs(ctx context.Context, state *runState, summary to } // buildToolDiffPayload 将工具结果 metadata 中的 diff 信息组装成 ToolDiffPayload。 -// 多文件工具(filesystem_move_file 等)使用 Files+Diffs 多路径字段; +// 使用 Files+Diffs 或 FilePath/Diff/WasNew 字段; // 其他写工具继续填充兼容字段 FilePath/Diff/WasNew,保持现有消费者不破。 // FileChange.Kind 优先取 entry.Kind(toolexec 收集层填充),缺失时回退到 WasNew 二分以兼容旧 metadata。 func buildToolDiffPayload(result tools.ToolResult) (ToolDiffPayload, bool) { @@ -1110,6 +1131,91 @@ func hasUserInputParts(parts []providertypes.ContentPart) bool { return false } +// projectImagesForModelRequest 根据当前模型能力生成 provider 可见的图片请求视图,历史图片降级不污染持久化消息。 +func projectImagesForModelRequest( + model string, + models []providertypes.ModelDescriptor, + messages []providertypes.Message, + currentInputParts []providertypes.ContentPart, +) ([]providertypes.Message, error) { + if !messagesContainImages(messages) { + return cloneMessagesForPersistence(messages), nil + } + if partsContainImages(currentInputParts) { + if err := rejectUnsupportedCurrentImageInput(model, models, currentInputParts); err != nil { + return nil, err + } + return cloneMessagesForPersistence(messages), nil + } + + hints := modelCapabilityHintsForRequest(model, models) + if hints.ImageInput == providertypes.ModelCapabilityStateSupported { + return cloneMessagesForPersistence(messages), nil + } + return projectHistoricalImagesAsText(messages), nil +} + +// rejectUnsupportedCurrentImageInput 在请求发送前拦截明确不支持本次图片输入的模型,避免上游返回协议级 400。 +func rejectUnsupportedCurrentImageInput( + model string, + models []providertypes.ModelDescriptor, + parts 
[]providertypes.ContentPart, +) error { + if !partsContainImages(parts) { + return nil + } + hints := modelCapabilityHintsForRequest(model, models) + if hints.ImageInput != providertypes.ModelCapabilityStateUnsupported { + return nil + } + return fmt.Errorf( + "runtime: model %q does not support image input; switch to a multimodal model or remove the image", + strings.TrimSpace(model), + ) +} + +// projectHistoricalImagesAsText 把 provider 请求中的历史图片分片替换为文本占位,保留消息顺序和其他文本内容。 +func projectHistoricalImagesAsText(messages []providertypes.Message) []providertypes.Message { + projected := cloneMessagesForPersistence(messages) + for messageIndex := range projected { + if len(projected[messageIndex].Parts) == 0 { + continue + } + parts := make([]providertypes.ContentPart, 0, len(projected[messageIndex].Parts)) + for _, part := range projected[messageIndex].Parts { + if part.Kind == providertypes.ContentPartImage && part.Image != nil { + parts = append(parts, providertypes.NewTextPart(historicalImageOmittedForModel)) + continue + } + parts = append(parts, part) + } + projected[messageIndex].Parts = parts + } + return projected +} + +// partsContainImages 判断当前输入分片中是否包含图片分片。 +func partsContainImages(parts []providertypes.ContentPart) bool { + for _, part := range parts { + if part.Kind == providertypes.ContentPartImage && part.Image != nil { + return true + } + } + return false +} + +// messagesContainImages 判断上下文消息中是否包含图片分片。 +func messagesContainImages(messages []providertypes.Message) bool { + for _, message := range messages { + for _, part := range message.Parts { + if part.Kind == providertypes.ContentPartImage && part.Image != nil { + return true + } + } + } + return false +} + // sessionTitleFromParts 从输入 parts 中提取一个合适的会话标题。 func sessionTitleFromParts(parts []providertypes.ContentPart) string { for _, part := range parts { diff --git a/internal/runtime/runtime.go b/internal/runtime/runtime.go index 2761784e8..af69c3e4f 100644 --- a/internal/runtime/runtime.go +++ 
b/internal/runtime/runtime.go @@ -75,6 +75,7 @@ type UserInput struct { // UserImageInput 表示用户输入中附带的单个图片引用(路径 + MIME)。 type UserImageInput struct { Path string + AssetID string MimeType string } @@ -258,12 +259,8 @@ func NewWithFactory( toolManager = tools.NewRegistry() } if contextBuilder == nil { - contextBuilder = agentcontext.NewConfiguredBuilder(agentcontext.MicroCompactConfig{ - Policies: toolManager, - Summarizers: toolManager, - }) + contextBuilder = agentcontext.NewConfiguredBuilder() } - service := &Service{ configManager: configManager, sessionStore: sessionStore, diff --git a/internal/runtime/runtime_remaining_branches_test.go b/internal/runtime/runtime_remaining_branches_test.go index ec084c1d1..1cf5a215d 100644 --- a/internal/runtime/runtime_remaining_branches_test.go +++ b/internal/runtime/runtime_remaining_branches_test.go @@ -32,14 +32,6 @@ func (m *callbackToolManager) ListAvailableSpecs(ctx context.Context, input tool return nil, ctx.Err() } -func (m *callbackToolManager) MicroCompactPolicy(name string) tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - -func (m *callbackToolManager) MicroCompactSummarizer(name string) tools.ContentSummarizer { - return nil -} - func (m *callbackToolManager) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { if m.executeFn != nil { return m.executeFn(ctx, input) diff --git a/internal/runtime/runtime_test.go b/internal/runtime/runtime_test.go index 4a85d7294..93a63d5e8 100644 --- a/internal/runtime/runtime_test.go +++ b/internal/runtime/runtime_test.go @@ -782,7 +782,6 @@ type stubTool struct { content string isError bool err error - policy tools.MicroCompactPolicy callCount int lastInput tools.ToolCallInput executeFn func(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) @@ -800,10 +799,6 @@ func (t *stubTool) Schema() map[string]any { return map[string]any{"type": "object"} } -func (t *stubTool) MicroCompactPolicy() 
tools.MicroCompactPolicy { - return t.policy -} - func (t *stubTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { t.callCount++ t.lastInput = input @@ -850,7 +845,6 @@ type stubToolManager struct { result tools.ToolResult err error listErr error - policies map[string]tools.MicroCompactPolicy listCalls int executeCalls int lastInput tools.ToolCallInput @@ -876,19 +870,6 @@ func (m *stubToolManager) ListAvailableSpecs(ctx context.Context, input tools.Sp return append([]providertypes.ToolSpec(nil), m.specs...), nil } -func (m *stubToolManager) MicroCompactPolicy(name string) tools.MicroCompactPolicy { - m.mu.Lock() - defer m.mu.Unlock() - if policy, ok := m.policies[name]; ok { - return policy - } - return tools.MicroCompactPolicyCompact -} - -func (m *stubToolManager) MicroCompactSummarizer(name string) tools.ContentSummarizer { - return nil -} - func (m *stubToolManager) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { m.mu.Lock() m.executeCalls++ @@ -1622,9 +1603,6 @@ func TestServiceRunDelegatesToContextBuilder(t *testing.T) { if builder.lastInput.Metadata.Model == "" { t.Fatalf("expected model to be forwarded to builder metadata") } - if builder.lastInput.Compact.DisableMicroCompact { - t.Fatalf("expected micro compact to stay enabled by default") - } if builder.lastInput.TaskState.Goal != "Finish task state rollout" { t.Fatalf("expected session task state to be forwarded to builder, got %+v", builder.lastInput.TaskState) } @@ -1716,47 +1694,6 @@ func TestServiceRunUsesSessionSelectionForMetadataAndBudget(t *testing.T) { } } -func TestServiceRunCanDisableMicroCompactViaConfig(t *testing.T) { - t.Parallel() - - manager := newRuntimeConfigManager(t) - if err := manager.Update(context.Background(), func(cfg *config.Config) error { - cfg.Context.Compact.MicroCompactDisabled = true - return nil - }); err != nil { - t.Fatalf("update config: %v", err) - } - - store := newMemoryStore() - registry 
:= tools.NewRegistry() - registry.Register(&stubTool{name: "filesystem_read_file", content: "default"}) - - builder := &stubContextBuilder{ - buildFn: func(ctx context.Context, input agentcontext.BuildInput) (agentcontext.BuildResult, error) { - return agentcontext.BuildResult{ - SystemPrompt: "delegated prompt", - Messages: append([]providertypes.Message(nil), input.Messages...), - }, nil - }, - } - - scripted := &scriptedProvider{ - responses: []scriptedResponse{{ - Message: providertypes.Message{Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("done")}}, - FinishReason: "stop", - }}, - } - - service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, builder) - if err := service.Run(context.Background(), UserInput{RunID: "run-disable-micro-compact", Parts: []providertypes.ContentPart{providertypes.NewTextPart("hello")}}); err != nil { - t.Fatalf("Run() error = %v", err) - } - - if !builder.lastInput.Compact.DisableMicroCompact { - t.Fatalf("expected config to disable micro compact in build input") - } -} - func TestServiceRunPersistsSessionProviderAndModel(t *testing.T) { t.Parallel() @@ -1786,131 +1723,6 @@ func TestServiceRunPersistsSessionProviderAndModel(t *testing.T) { } } -func TestServiceRunDefaultBuilderUsesToolManagerMicroCompactPolicies(t *testing.T) { - t.Parallel() - - manager := newRuntimeConfigManager(t) - store := newMemoryStore() - registry := tools.NewRegistry() - registry.Register(&stubTool{name: "preserve_tool", content: "default", policy: tools.MicroCompactPolicyPreserveHistory}) - registry.Register(&stubTool{name: "bash", content: "default"}) - registry.Register(&stubTool{name: "webfetch", content: "default"}) - - session := agentsession.New("preserve history") - session.ID = "session-preserve-history" - session.Messages = []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - 
{ - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "preserve_tool", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("preserved result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - } - store.sessions[session.ID] = cloneSession(session) - - scripted := &scriptedProvider{ - responses: []scriptedResponse{{ - Message: providertypes.Message{Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("done")}}, - FinishReason: "stop", - }}, - } - - service := NewWithFactory(manager, registry, store, &scriptedProviderFactory{provider: scripted}, nil) - if err := service.Run(context.Background(), UserInput{ - SessionID: session.ID, - RunID: "run-preserve-history-policy", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}, - }); err != nil { - t.Fatalf("Run() error = %v", err) - } - - if len(scripted.requests) != 1 { - t.Fatalf("expected 1 provider request, got %d", len(scripted.requests)) - } - if got := renderPartsForTest(scripted.requests[0].Messages[2].Parts); got != "preserved result" { - t.Fatalf("expected preserved tool result to remain visible, got %q", got) - } -} - -func TestServiceRunDefaultBuilderUsesGenericToolManagerMicroCompactPolicies(t *testing.T) { - t.Parallel() - - manager := newRuntimeConfigManager(t) - store := 
newMemoryStore() - toolManager := &stubToolManager{ - policies: map[string]tools.MicroCompactPolicy{ - "preserve_tool": tools.MicroCompactPolicyPreserveHistory, - }, - } - - session := agentsession.New("preserve history by manager") - session.ID = "session-preserve-history-manager" - session.Messages = []providertypes.Message{ - {Role: providertypes.RoleUser, Parts: []providertypes.ContentPart{providertypes.NewTextPart("older user")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-1", Name: "preserve_tool", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-1", Parts: []providertypes.ContentPart{providertypes.NewTextPart("preserved result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-2", Name: "bash", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-2", Parts: []providertypes.ContentPart{providertypes.NewTextPart("recent bash result")}}, - { - Role: providertypes.RoleAssistant, - ToolCalls: []providertypes.ToolCall{ - {ID: "call-3", Name: "webfetch", Arguments: "{}"}, - }, - }, - {Role: providertypes.RoleTool, ToolCallID: "call-3", Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest webfetch result")}}, - } - store.sessions[session.ID] = cloneSession(session) - - scripted := &scriptedProvider{ - responses: []scriptedResponse{{ - Message: providertypes.Message{Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("done")}}, - FinishReason: "stop", - }}, - } - - service := NewWithFactory(manager, toolManager, store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) - if err := service.Run(context.Background(), UserInput{ - SessionID: session.ID, - RunID: "run-preserve-history-generic-manager", - Parts: []providertypes.ContentPart{providertypes.NewTextPart("latest explicit instruction")}, - }); err != nil { - t.Fatalf("Run() error = %v", 
err) - } - - if len(scripted.requests) != 1 { - t.Fatalf("expected 1 provider request, got %d", len(scripted.requests)) - } - if got := renderPartsForTest(scripted.requests[0].Messages[2].Parts); got != "preserved result" { - t.Fatalf("expected preserved tool result to remain visible, got %q", got) - } -} - func TestServiceRunFailurePreservesExistingSessionProviderAndModel(t *testing.T) { t.Parallel() @@ -3943,6 +3755,7 @@ func TestServiceRunPlanModePersistsDraftPlan(t *testing.T) { }); err != nil { t.Fatalf("Run() error = %v", err) } + events := collectRuntimeEvents(service.Events()) saved := onlySession(t, store) if saved.AgentMode != agentsession.AgentModePlan { @@ -3978,9 +3791,26 @@ func TestServiceRunPlanModePersistsDraftPlan(t *testing.T) { if got := renderPartsForTest(saved.Messages[2].Parts); !strings.Contains(got, "目标") { t.Fatalf("expected rendered plan content, got %q", got) } + var planEvent RuntimeEvent + for _, event := range events { + if event.Type == EventPlanUpdated { + planEvent = event + break + } + } + if planEvent.Type != EventPlanUpdated { + t.Fatalf("expected %s event, got events %+v", EventPlanUpdated, eventTypes(events)) + } + payload, ok := planEvent.Payload.(PlanUpdatedPayload) + if !ok { + t.Fatalf("plan event payload = %T, want PlanUpdatedPayload", planEvent.Payload) + } + if payload.CurrentPlan == nil || payload.CurrentPlan.Spec.Goal != "为 runtime 引入 plan/build 模式" { + t.Fatalf("unexpected plan event payload: %+v", payload.CurrentPlan) + } } -func TestServiceRunPlanModeShowsExplanationTextOutsidePlanningJSON(t *testing.T) { +func TestServiceRunPlanModePersistsCanonicalMarkdownInsteadOfPlanningJSON(t *testing.T) { t.Parallel() manager := newRuntimeConfigManager(t) @@ -4032,8 +3862,12 @@ func TestServiceRunPlanModeShowsExplanationTextOutsidePlanningJSON(t *testing.T) if strings.Contains(got, "\"plan_spec\"") { t.Fatalf("expected persisted assistant text to strip planning JSON, got %q", got) } - if !strings.Contains(got, "先确认范围") || 
!strings.Contains(got, "继续执行") { - t.Fatalf("expected prose explanation to be preserved, got %q", got) + if strings.Contains(got, "先确认范围") || strings.Contains(got, "继续执行") { + t.Fatalf("expected model prose to be replaced by canonical markdown, got %q", got) + } + if !strings.Contains(got, "### 目标") || !strings.Contains(got, "Preserve prose around planning JSON") || + !strings.Contains(got, "### 实施步骤") || !strings.Contains(got, "- persist plan") { + t.Fatalf("expected canonical markdown plan, got %q", got) } } @@ -4926,10 +4760,280 @@ func TestServiceRunCompactedSessionRequestsRestoreAlignment(t *testing.T) { } } +func TestRunRejectsImageInputForUnsupportedModelBeforeProviderBuild(t *testing.T) { + manager := newRuntimeConfigManager(t) + configureCurrentModelImageInputForTest(t, manager, providertypes.ModelCapabilityStateUnsupported) + + store := newMemoryStore() + factory := &scriptedProviderFactory{provider: &scriptedProvider{}} + service := NewWithFactory(manager, tools.NewRegistry(), store, factory, &stubContextBuilder{}) + + err := service.Run(context.Background(), UserInput{ + RunID: "run-image-unsupported", + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("look at this"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + }) + if err == nil || !strings.Contains(err.Error(), "does not support image input") { + t.Fatalf("Run() error = %v, want image input unsupported", err) + } + if factory.calls != 0 { + t.Fatalf("provider build calls = %d, want 0", factory.calls) + } +} + +func TestRunProjectsHistoricalImagesForUnsupportedModelRequest(t *testing.T) { + manager := newRuntimeConfigManager(t) + configureCurrentModelImageInputForTest(t, manager, providertypes.ModelCapabilityStateUnsupported) + + store := newMemoryStore() + seed := agentsession.New("image history") + seed.ID = "session-image-history" + seed.Messages = []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{ + 
providertypes.NewTextPart("old image"), + providertypes.NewSessionAssetImagePart("asset-old", "image/png"), + }, + }, + {Role: providertypes.RoleAssistant, Parts: []providertypes.ContentPart{providertypes.NewTextPart("图片已收到")}}, + } + store.sessions[seed.ID] = cloneSession(seed) + + scripted := &scriptedProvider{ + streams: [][]providertypes.StreamEvent{{ + providertypes.NewTextDeltaStreamEvent("done"), + providertypes.NewMessageDoneStreamEvent("", nil), + }}, + } + service := NewWithFactory(manager, tools.NewRegistry(), store, &scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + + if err := service.Run(context.Background(), UserInput{ + SessionID: seed.ID, + RunID: "run-project-historical-image", + Parts: []providertypes.ContentPart{providertypes.NewTextPart("continue without image")}, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + if len(scripted.requests) != 1 { + t.Fatalf("provider requests = %d, want 1", len(scripted.requests)) + } + if messagesContainImages(scripted.requests[0].Messages) { + t.Fatalf("provider request still contains image: %+v", scripted.requests[0].Messages) + } + if got := renderPartsForTest(scripted.requests[0].Messages[0].Parts); !strings.Contains(got, historicalImageOmittedForModel) { + t.Fatalf("projected historical image text = %q, want placeholder", got) + } + + saved, err := store.Load(context.Background(), seed.ID) + if err != nil { + t.Fatalf("load saved session: %v", err) + } + if !messagesContainImages(saved.Messages[:1]) { + t.Fatalf("saved historical message lost image: %+v", saved.Messages[:1]) + } +} + +func TestProjectImagesForModelRequestCapabilityStates(t *testing.T) { + imageMessages := []providertypes.Message{{ + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }} + currentText := []providertypes.ContentPart{providertypes.NewTextPart("hello")} + currentImage := 
[]providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-new", "image/png")} + + tests := []struct { + name string + model string + models []providertypes.ModelDescriptor + current []providertypes.ContentPart + wantErr bool + wantImage bool + wantOmitted bool + }{ + { + name: "unsupported historical image is omitted", + model: "model-a", + models: []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnsupported, + }}}, + current: currentText, + wantOmitted: true, + }, + { + name: "unknown historical image is omitted", + model: "model-a", + models: []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnknown, + }}}, + current: currentText, + wantOmitted: true, + }, + { + name: "missing model historical image is omitted", + model: "missing", + models: []providertypes.ModelDescriptor{{ID: "other"}}, + current: currentText, + wantOmitted: true, + }, + { + name: "supported historical image is retained", + model: "model-a", + models: []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateSupported, + }}}, + current: currentText, + wantImage: true, + }, + { + name: "unsupported current image rejects", + model: "model-a", + models: []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnsupported, + }}}, + current: currentImage, + wantErr: true, + }, + { + name: "unknown current image is retained", + model: "model-a", + models: []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnknown, + }}}, + current: currentImage, + wantImage: true, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + projected, err := projectImagesForModelRequest(tt.model, tt.models, imageMessages, tt.current) + if tt.wantErr && err == nil { + t.Fatal("expected unsupported image input error") + } + if !tt.wantErr && err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.wantErr { + return + } + if got := messagesContainImages(projected); got != tt.wantImage { + t.Fatalf("messagesContainImages(projected) = %t, want %t; projected=%+v", got, tt.wantImage, projected) + } + if got := strings.Contains(renderPartsForTest(projected[0].Parts), historicalImageOmittedForModel); got != tt.wantOmitted { + t.Fatalf("placeholder present = %t, want %t; projected=%+v", got, tt.wantOmitted, projected) + } + if !messagesContainImages(imageMessages) { + t.Fatalf("original messages were mutated: %+v", imageMessages) + } + }) + } +} + +// TestProjectImagesForModelRequestEmptyMessages 验证空消息列表不会 panic 并正确返回空结果。 +func TestProjectImagesForModelRequestEmptyMessages(t *testing.T) { + t.Parallel() + models := []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnsupported, + }}} + projected, err := projectImagesForModelRequest("model-a", models, nil, []providertypes.ContentPart{providertypes.NewTextPart("hello")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(projected) != 0 { + t.Fatalf("expected empty projected messages, got %d", len(projected)) + } +} + +// TestProjectImagesForModelRequestMixedMessages 验证混合图片和文本的消息中只有图片被投影降级。 +func TestProjectImagesForModelRequestMixedMessages(t *testing.T) { + t.Parallel() + messages := []providertypes.Message{ + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("text before image")}, + }, + { + Role: providertypes.RoleUser, + Parts: []providertypes.ContentPart{providertypes.NewSessionAssetImagePart("asset-1", "image/png")}, + }, + { + Role: 
providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("response")}, + }, + } + models := []providertypes.ModelDescriptor{{ID: "model-a", CapabilityHints: providertypes.ModelCapabilityHints{ + ImageInput: providertypes.ModelCapabilityStateUnsupported, + }}} + projected, err := projectImagesForModelRequest("model-a", models, messages, []providertypes.ContentPart{providertypes.NewTextPart("current text")}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(projected) != 3 { + t.Fatalf("expected 3 messages, got %d", len(projected)) + } + // 第一条消息应保持纯文本不变 + if projected[0].Parts[0].Kind != providertypes.ContentPartText { + t.Fatalf("first message should be text, got %s", projected[0].Parts[0].Kind) + } + // 第二条消息中的图片应被替换为占位文本 + if messagesContainImages(projected) { + t.Fatal("expected no images in projected messages for unsupported model") + } + if !strings.Contains(renderPartsForTest(projected[1].Parts), historicalImageOmittedForModel) { + t.Fatal("expected historical image omitted placeholder in second message") + } + // 第三条消息应保持不变 + if projected[2].Parts[0].Kind != providertypes.ContentPartText { + t.Fatalf("third message should be text, got %s", projected[2].Parts[0].Kind) + } + // 原始消息不应被修改 + if !messagesContainImages(messages) { + t.Fatal("original messages should still contain images") + } +} + func newRuntimeConfigManager(t *testing.T) *config.Manager { return newRuntimeConfigManagerWithProviderEnvs(t, nil) } +func configureCurrentModelImageInputForTest( + t *testing.T, + manager *config.Manager, + state providertypes.ModelCapabilityState, +) { + t.Helper() + providerName := config.OpenAIName + modelID := "gpt-4o" + apiKeyEnv := config.OpenAIDefaultAPIKeyEnv + if state == providertypes.ModelCapabilityStateUnsupported { + providerName = config.QiniuName + modelID = config.QiniuDefaultModel + apiKeyEnv = config.QiniuDefaultAPIKeyEnv + } + t.Setenv(apiKeyEnv, "test-key") + if err := 
manager.Update(context.Background(), func(cfg *config.Config) error { + providerIndex := -1 + for index := range cfg.Providers { + if cfg.Providers[index].Name == providerName { + providerIndex = index + break + } + } + if providerIndex < 0 { + return fmt.Errorf("test config has no provider %q", providerName) + } + cfg.SelectedProvider = providerName + cfg.CurrentModel = modelID + cfg.Providers[providerIndex].Model = modelID + return nil + }); err != nil { + t.Fatalf("configure image input capability: %v", err) + } +} + func newRuntimeConfigManagerWithProviderEnvs(t *testing.T, providerEnvs map[string]string) *config.Manager { t.Helper() @@ -6352,6 +6456,90 @@ func TestServiceRunAllowsAfterProactiveCompactWhenEstimateAdvisory(t *testing.T) } } +func TestServiceRunAllowsImageRequestWithinProjectedBudget(t *testing.T) { + t.Parallel() + + manager := newRuntimeConfigManager(t) + if err := manager.Update(context.Background(), func(cfg *config.Config) error { + cfg.Context.Budget.PromptBudget = 5000 + cfg.Context.Budget.FallbackPromptBudget = 5000 + return nil + }); err != nil { + t.Fatalf("update config: %v", err) + } + + store := newMemoryStore() + registry := tools.NewRegistry() + scripted := &scriptedProvider{ + estimateFn: func(ctx context.Context, req providertypes.GenerateRequest) (providertypes.BudgetEstimate, error) { + _ = ctx + tokens, err := provider.EstimateProjectedInputTokens(req, provider.ResolveRequestModel(req, "gpt-4.1")) + if err != nil { + return providertypes.BudgetEstimate{}, err + } + return providertypes.BudgetEstimate{ + EstimatedInputTokens: tokens, + EstimateSource: provider.EstimateSourceLocal, + GatePolicy: provider.EstimateGateAdvisory, + }, nil + }, + responses: []scriptedResponse{ + { + Message: providertypes.Message{ + Role: providertypes.RoleAssistant, + Parts: []providertypes.ContentPart{providertypes.NewTextPart("图片已收到")}, + }, + FinishReason: "stop", + }, + }, + } + + service := NewWithFactory(manager, registry, store, 
&scriptedProviderFactory{provider: scripted}, &stubContextBuilder{}) + service.compactRunner = &stubCompactRunner{} + + if err := service.Run(context.Background(), UserInput{ + RunID: "run-budget-image-allow", + Parts: []providertypes.ContentPart{ + providertypes.NewTextPart("describe"), + providertypes.NewSessionAssetImagePart("asset-1", "image/png"), + }, + }); err != nil { + t.Fatalf("Run() error = %v", err) + } + + if scripted.callCount != 1 { + t.Fatalf("expected provider Generate to be called once, got %d", scripted.callCount) + } + if compactRunner := service.compactRunner.(*stubCompactRunner); len(compactRunner.calls) != 0 { + t.Fatalf("expected no proactive compact for projected image estimate, got %d calls", len(compactRunner.calls)) + } + + events := collectRuntimeEvents(service.Events()) + var budgetPayload *BudgetCheckedPayload + for _, event := range events { + if event.Type != EventBudgetChecked { + continue + } + payload, ok := event.Payload.(BudgetCheckedPayload) + if !ok { + t.Fatalf("expected BudgetCheckedPayload, got %T", event.Payload) + } + budgetPayload = &payload + break + } + if budgetPayload == nil { + t.Fatalf("expected budget_checked event, got %+v", events) + } + if budgetPayload.Action != string(controlplane.TurnBudgetActionAllow) || + budgetPayload.Reason != controlplane.BudgetDecisionReasonWithinBudget { + t.Fatalf("unexpected budget decision: %+v", budgetPayload) + } + if budgetPayload.EstimatedInputTokens <= provider.DefaultImageInputTokenEstimate || + budgetPayload.EstimatedInputTokens >= budgetPayload.PromptBudget { + t.Fatalf("unexpected projected image estimate: %+v", budgetPayload) + } +} + func TestServiceRunStopsAfterNoOpProactiveCompactWhenEstimateGateable(t *testing.T) { t.Parallel() diff --git a/internal/runtime/state.go b/internal/runtime/state.go index e7993eff4..ff34d73df 100644 --- a/internal/runtime/state.go +++ b/internal/runtime/state.go @@ -38,6 +38,7 @@ type runState struct { maxTurnsReached bool maxTurnsLimit 
int userGoal string + currentInputParts []providertypes.ContentPart pendingSystemReminder string acceptanceContinueCount int toolTimeoutBackoff map[string]int diff --git a/internal/runtime/todo_bootstrap.go b/internal/runtime/todo_bootstrap.go index 3f05f9330..75c1c1daa 100644 --- a/internal/runtime/todo_bootstrap.go +++ b/internal/runtime/todo_bootstrap.go @@ -63,8 +63,9 @@ plan_bootstrap_required: You are in plan mode but no current plan exists. Before research, analysis, or conversational response, you MUST complete the following: 1. Research the codebase as needed using read-only tools. -2. Output a JSON object containing "plan_spec" and "summary_candidate" that defines the current plan. -3. Focus plan_spec on goal, steps, constraints, and open_questions. Do not create execution todos in plan mode. +2. Output the visible plan as Markdown first, using short sections for goal, steps, constraints, and open questions. +3. After the Markdown, include one compact machine-readable JSON object containing "plan_spec" and "summary_candidate". Put this JSON inside an HTML comment, not in a fenced code block. +4. Focus plan_spec on goal, steps, constraints, and open_questions. Do not create execution todos in plan mode. 
Do not end this turn without producing a plan.` diff --git a/internal/runtime/tool_diff_helpers_test.go b/internal/runtime/tool_diff_helpers_test.go index e4071a408..f9864f805 100644 --- a/internal/runtime/tool_diff_helpers_test.go +++ b/internal/runtime/tool_diff_helpers_test.go @@ -33,7 +33,7 @@ func TestBuildToolDiffPayload(t *testing.T) { t.Run("multi file payload", func(t *testing.T) { result := tools.ToolResult{ - Name: tools.ToolNameFilesystemMoveFile, + Name: tools.ToolNameFilesystemWriteFile, ToolCallID: "call-2", Metadata: map[string]any{ "tool_diffs": []map[string]any{ @@ -61,7 +61,7 @@ func TestBuildToolDiffPayload(t *testing.T) { t.Run("multi file kind from metadata wins over WasNew fallback", func(t *testing.T) { result := tools.ToolResult{ - Name: tools.ToolNameFilesystemMoveFile, + Name: tools.ToolNameFilesystemWriteFile, ToolCallID: "call-move", Metadata: map[string]any{ "tool_diffs": []map[string]any{ @@ -88,7 +88,7 @@ func TestBuildToolDiffPayload(t *testing.T) { t.Run("multi file filters unchanged copy source", func(t *testing.T) { result := tools.ToolResult{ - Name: tools.ToolNameFilesystemCopyFile, + Name: tools.ToolNameFilesystemWriteFile, ToolCallID: "call-copy", Metadata: map[string]any{ "tool_diffs": []map[string]any{ @@ -112,7 +112,7 @@ func TestBuildToolDiffPayload(t *testing.T) { t.Run("multi file delete metadata preserves deleted kind", func(t *testing.T) { result := tools.ToolResult{ - Name: tools.ToolNameFilesystemRemoveDir, + Name: tools.ToolNameFilesystemWriteFile, ToolCallID: "call-rm", Metadata: map[string]any{ "tool_diffs": []map[string]any{ @@ -144,7 +144,7 @@ func TestBuildToolDiffPayload(t *testing.T) { } func TestToolExecutionHelperFunctions(t *testing.T) { - t.Run("toolCallTouchedPaths covers write and move payloads", func(t *testing.T) { + t.Run("toolCallTouchedPaths extracts path", func(t *testing.T) { writePaths := toolCallTouchedPaths(providertypes.ToolCall{ Name: tools.ToolNameFilesystemWriteFile, Arguments: `{"path":" 
docs/readme.md "}`, @@ -153,16 +153,8 @@ func TestToolExecutionHelperFunctions(t *testing.T) { t.Fatalf("write toolCallTouchedPaths() = %#v", writePaths) } - movePaths := toolCallTouchedPaths(providertypes.ToolCall{ - Name: tools.ToolNameFilesystemMoveFile, - Arguments: `{"source_path":"src/a.txt","destination_path":" /tmp/b.txt "}`, - }, "/repo") - if len(movePaths) != 2 || movePaths[0] != "/repo/src/a.txt" || movePaths[1] != "/tmp/b.txt" { - t.Fatalf("move toolCallTouchedPaths() = %#v", movePaths) - } - if got := toolCallTouchedPaths(providertypes.ToolCall{ - Name: tools.ToolNameFilesystemCopyFile, + Name: tools.ToolNameFilesystemWriteFile, Arguments: `{invalid`, }, "/repo"); got != nil { t.Fatalf("malformed toolCallTouchedPaths() = %#v, want nil", got) diff --git a/internal/runtime/toolexec.go b/internal/runtime/toolexec.go index efc3cc71d..8b57a2477 100644 --- a/internal/runtime/toolexec.go +++ b/internal/runtime/toolexec.go @@ -5,11 +5,12 @@ import ( "encoding/json" "errors" "fmt" - "os" "path/filepath" + "regexp" "sort" "strings" "sync" + "unicode" "neo-code/internal/checkpoint" providertypes "neo-code/internal/provider/types" @@ -23,6 +24,12 @@ type indexedToolCall struct { call providertypes.ToolCall } +const hookToolArgumentsPreviewMaxChars = 512 + +var hookToolArgumentsSensitivePattern = regexp.MustCompile( + `(?i)(token|password|secret|api[_-]?key|access[_-]?key|auth)\s*[:=]\s*("[^"]*"|'[^']*'|[^\s]+)`, +) + // executeAssistantToolCalls 并发执行 assistant 返回的全部工具调用并返回结构化执行摘要。 func (s *Service) executeAssistantToolCalls( ctx context.Context, @@ -161,10 +168,11 @@ func (s *Service) executeOneToolCallWithoutPersistence( beforeToolHookOutput := s.runHookPoint(ctx, state, runtimehooks.HookPointBeforeToolCall, runtimehooks.HookContext{ Metadata: map[string]any{ - "tool_call_id": strings.TrimSpace(call.ID), - "tool_name": strings.TrimSpace(call.Name), - "tool_arguments": strings.TrimSpace(call.Arguments), - "workdir": strings.TrimSpace(snapshot.Workdir), + 
"tool_call_id": strings.TrimSpace(call.ID), + "tool_name": strings.TrimSpace(call.Name), + "tool_arguments": strings.TrimSpace(call.Arguments), + "tool_arguments_preview": buildToolArgumentsPreview(call.Arguments), + "workdir": strings.TrimSpace(snapshot.Workdir), }, }) if beforeToolHookOutput.Blocked { @@ -200,7 +208,6 @@ func (s *Service) executeOneToolCallWithoutPersistence( var bashCommand string var bashChangedPaths []string var touchedPaths []string - var removeDirNestedPaths []string if isWrite { touchedPaths = toolCallTouchedPaths(call, snapshot.Workdir) @@ -211,21 +218,6 @@ func (s *Service) executeOneToolCallWithoutPersistence( if s.perEditStore != nil { _, _ = s.perEditStore.CapturePreWrite(p) } - // remove_dir: recursively pre-capture all nested files/dirs. - if strings.EqualFold(strings.TrimSpace(call.Name), tools.ToolNameFilesystemRemoveDir) { - if info, err := os.Stat(p); err == nil && info.IsDir() { - _ = filepath.WalkDir(p, func(path string, d os.DirEntry, err error) error { - if err != nil || path == p { - return nil - } - removeDirNestedPaths = append(removeDirNestedPaths, path) - if s.perEditStore != nil { - _, _ = s.perEditStore.CapturePreWrite(path) - } - return nil - }) - } - } } } } else if isBash && s.perEditStore != nil { @@ -290,17 +282,6 @@ func (s *Service) executeOneToolCallWithoutPersistence( if isWrite && execErr == nil && !result.IsError && s.perEditStore != nil { switch strings.TrimSpace(call.Name) { - case tools.ToolNameFilesystemRemoveDir: - if len(removeDirNestedPaths) > 0 && len(touchedPaths) > 0 { - allPaths := append([]string{touchedPaths[0]}, removeDirNestedPaths...) 
- _ = s.perEditStore.CapturePostDelete(allPaths) - } else if len(touchedPaths) > 0 { - _ = s.perEditStore.CapturePostDelete(touchedPaths) - } - case tools.ToolNameFilesystemMoveFile: - if len(touchedPaths) > 1 { - _ = s.perEditStore.CapturePostDelete([]string{touchedPaths[0]}) - } case tools.ToolNameFilesystemDeleteFile: if len(touchedPaths) > 0 { _ = s.perEditStore.CapturePostDelete(touchedPaths) @@ -594,42 +575,25 @@ func isFileWriteTool(name string) bool { switch strings.TrimSpace(name) { case tools.ToolNameFilesystemWriteFile, tools.ToolNameFilesystemEdit, - tools.ToolNameFilesystemMoveFile, - tools.ToolNameFilesystemCopyFile, - tools.ToolNameFilesystemDeleteFile, - tools.ToolNameFilesystemCreateDir, - tools.ToolNameFilesystemRemoveDir: + tools.ToolNameFilesystemDeleteFile: return true } return false } // toolCallTouchedPaths 从工具调用参数中提取所有可能被修改的工作区绝对路径。 -// move/copy 同时返回 source 与 destination;其他写工具返回单个 path。 func toolCallTouchedPaths(call providertypes.ToolCall, workdir string) []string { args := strings.TrimSpace(call.Arguments) if args == "" { return nil } - switch strings.TrimSpace(call.Name) { - case tools.ToolNameFilesystemMoveFile, tools.ToolNameFilesystemCopyFile: - var parsed struct { - SourcePath string `json:"source_path"` - DestinationPath string `json:"destination_path"` - } - if err := json.Unmarshal([]byte(args), &parsed); err != nil { - return nil - } - return resolveWorkdirPaths(workdir, parsed.SourcePath, parsed.DestinationPath) - default: - var parsed struct { - Path string `json:"path"` - } - if err := json.Unmarshal([]byte(args), &parsed); err != nil { - return nil - } - return resolveWorkdirPaths(workdir, parsed.Path) + var parsed struct { + Path string `json:"path"` } + if err := json.Unmarshal([]byte(args), &parsed); err != nil { + return nil + } + return resolveWorkdirPaths(workdir, parsed.Path) } // resolveWorkdirPaths 将多个相对/绝对路径解析为工作区绝对路径,丢弃空字符串。 @@ -797,6 +761,122 @@ func summarizeHookResultContent(content string) string { return 
trimmed[:256] } +// buildToolArgumentsPreview 生成 matcher 可用的参数预览,并对敏感键值执行脱敏。 +func buildToolArgumentsPreview(arguments string) string { + trimmed := strings.TrimSpace(arguments) + if trimmed == "" { + return "" + } + masked := sanitizeHookToolArguments(trimmed) + return truncateHookTextByChars(masked, hookToolArgumentsPreviewMaxChars) +} + +// sanitizeHookToolArguments 优先按 JSON 结构递归脱敏,非 JSON 输入回退为轻量正则脱敏。 +func sanitizeHookToolArguments(arguments string) string { + if masked, ok := sanitizeHookToolArgumentsJSON(arguments); ok { + return masked + } + return hookToolArgumentsSensitivePattern.ReplaceAllString(arguments, `$1=***`) +} + +// sanitizeHookToolArgumentsJSON 尝试解析 JSON 并按敏感键递归替换值。 +func sanitizeHookToolArgumentsJSON(arguments string) (string, bool) { + var decoded any + if err := json.Unmarshal([]byte(arguments), &decoded); err != nil { + return "", false + } + sanitized := maskHookToolArgumentValue(decoded) + encoded, err := json.Marshal(sanitized) + if err != nil { + return "", false + } + return string(encoded), true +} + +// maskHookToolArgumentValue 递归处理 JSON 节点,对敏感键对应的值统一替换为 "***"。 +func maskHookToolArgumentValue(value any) any { + switch typed := value.(type) { + case map[string]any: + masked := make(map[string]any, len(typed)) + for key, item := range typed { + if isSensitiveHookToolArgumentKey(key) { + masked[key] = "***" + continue + } + masked[key] = maskHookToolArgumentValue(item) + } + return masked + case []any: + masked := make([]any, len(typed)) + for index, item := range typed { + masked[index] = maskHookToolArgumentValue(item) + } + return masked + default: + return value + } +} + +// isSensitiveHookToolArgumentKey 判断参数键名是否属于敏感信息字段。 +func isSensitiveHookToolArgumentKey(key string) bool { + tokens := tokenizeHookToolArgumentKey(key) + if len(tokens) == 0 { + return false + } + for index, token := range tokens { + switch token { + case "password", "passwd", "secret", "token", "auth", "authorization": + return true + case "apikey", "accesskey", 
"authtoken", "accesstoken": + return true + case "api", "access": + if index+1 < len(tokens) && tokens[index+1] == "key" { + return true + } + case "key": + if index > 0 && (tokens[index-1] == "api" || tokens[index-1] == "access") { + return true + } + } + } + return false +} + +// tokenizeHookToolArgumentKey 将参数键拆分为小写词元,兼容 snake/kebab/camelCase。 +func tokenizeHookToolArgumentKey(key string) []string { + trimmed := strings.TrimSpace(key) + if trimmed == "" { + return nil + } + var builder strings.Builder + var prev rune + for _, current := range trimmed { + switch { + case unicode.IsLetter(current) || unicode.IsDigit(current): + if unicode.IsUpper(current) && unicode.IsLower(prev) { + builder.WriteByte(' ') + } + builder.WriteRune(unicode.ToLower(current)) + default: + builder.WriteByte(' ') + } + prev = current + } + return strings.Fields(builder.String()) +} + +// truncateHookTextByChars 按字符长度截断文本,避免 metadata 放大。 +func truncateHookTextByChars(text string, maxChars int) string { + if maxChars <= 0 { + return "" + } + runes := []rune(text) + if len(runes) <= maxChars { + return text + } + return string(runes[:maxChars]) +} + // extractTodoIDsFromPayload 提取 todo 事件快照中的条目 ID,用于冲突事实去重统计。 func extractTodoIDsFromPayload(items []TodoViewItem) []string { if len(items) == 0 { @@ -873,11 +953,12 @@ func (s *Service) emitAfterToolFailureHook( workdir string, ) { afterToolFailureMetadata := map[string]any{ - "tool_call_id": strings.TrimSpace(call.ID), - "tool_name": strings.TrimSpace(call.Name), - "is_error": result.IsError, - "error_class": strings.TrimSpace(result.ErrorClass), - "workdir": strings.TrimSpace(workdir), + "tool_call_id": strings.TrimSpace(call.ID), + "tool_name": strings.TrimSpace(call.Name), + "tool_arguments_preview": buildToolArgumentsPreview(call.Arguments), + "is_error": result.IsError, + "error_class": strings.TrimSpace(result.ErrorClass), + "workdir": strings.TrimSpace(workdir), } if execErr != nil { afterToolFailureMetadata["execution_error"] = 
strings.TrimSpace(execErr.Error()) diff --git a/internal/runtime/toolexec_preview_test.go b/internal/runtime/toolexec_preview_test.go new file mode 100644 index 000000000..1b2723565 --- /dev/null +++ b/internal/runtime/toolexec_preview_test.go @@ -0,0 +1,81 @@ +package runtime + +import ( + "strings" + "testing" +) + +func TestBuildToolArgumentsPreviewMaskJSONSensitiveFields(t *testing.T) { + t.Parallel() + + raw := `{"api_key":"sk-123","password":"p@ss","nested":{"secret":"abc"},"safe":"ok"}` + preview := buildToolArgumentsPreview(raw) + if strings.Contains(preview, "sk-123") { + t.Fatalf("preview leaked api_key: %q", preview) + } + if strings.Contains(preview, "p@ss") { + t.Fatalf("preview leaked password: %q", preview) + } + if strings.Contains(preview, `"secret":"abc"`) { + t.Fatalf("preview leaked nested secret: %q", preview) + } + if !strings.Contains(preview, `"api_key":"***"`) { + t.Fatalf("preview should mask api_key: %q", preview) + } + if !strings.Contains(preview, `"password":"***"`) { + t.Fatalf("preview should mask password: %q", preview) + } + if !strings.Contains(preview, `"secret":"***"`) { + t.Fatalf("preview should mask nested secret: %q", preview) + } + if !strings.Contains(preview, `"safe":"ok"`) { + t.Fatalf("preview should keep non-sensitive keys: %q", preview) + } +} + +func TestBuildToolArgumentsPreviewMaskNonJSONFallback(t *testing.T) { + t.Parallel() + + preview := buildToolArgumentsPreview(`token=abc password:xyz arg=ok`) + if strings.Contains(preview, "abc") || strings.Contains(preview, "xyz") { + t.Fatalf("preview leaked fallback credentials: %q", preview) + } + if !strings.Contains(preview, "token=***") { + t.Fatalf("preview should mask token in fallback mode: %q", preview) + } + if !strings.Contains(preview, "password=***") { + t.Fatalf("preview should mask password in fallback mode: %q", preview) + } +} + +func TestBuildToolArgumentsPreviewTruncate(t *testing.T) { + t.Parallel() + + raw := strings.Repeat("a", 
hookToolArgumentsPreviewMaxChars+20) + preview := buildToolArgumentsPreview(raw) + if len([]rune(preview)) != hookToolArgumentsPreviewMaxChars { + t.Fatalf("preview length=%d, want %d", len([]rune(preview)), hookToolArgumentsPreviewMaxChars) + } +} + +func TestIsSensitiveHookToolArgumentKey(t *testing.T) { + t.Parallel() + + cases := []struct { + key string + want bool + }{ + {key: "api_key", want: true}, + {key: "accessKey", want: true}, + {key: "authorization", want: true}, + {key: "auth_token", want: true}, + {key: "password", want: true}, + {key: "author", want: false}, + {key: "tool_name", want: false}, + } + for _, tc := range cases { + if got := isSensitiveHookToolArgumentKey(tc.key); got != tc.want { + t.Fatalf("isSensitiveHookToolArgumentKey(%q)=%v, want %v", tc.key, got, tc.want) + } + } +} diff --git a/internal/runtime/user_hooks.go b/internal/runtime/user_hooks.go index c9557e2e8..86434da4e 100644 --- a/internal/runtime/user_hooks.go +++ b/internal/runtime/user_hooks.go @@ -4,17 +4,14 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net" "net/http" "net/url" "os" - "os/exec" "path/filepath" "runtime" - "slices" "strings" "time" @@ -199,35 +196,45 @@ func buildConfiguredHookSpec( if err := validateConfiguredHookItemForP6Lite(item, scope); err != nil { return runtimehooks.HookSpec{}, err } + point := runtimehooks.HookPoint(strings.TrimSpace(item.Point)) + matcher, err := buildConfiguredHookMatcher(item, point) + if err != nil { + return runtimehooks.HookSpec{}, err + } kind := strings.ToLower(strings.TrimSpace(item.Kind)) specKind := runtimehooks.HookKindFunction specMode := runtimehooks.HookModeSync var ( - handler runtimehooks.HookHandler - err error + handler runtimehooks.HookHandler + buildErr error ) switch kind { case configuredHookKindBuiltin: - handler, err = buildUserBuiltinHookHandler(strings.TrimSpace(item.Handler), item.Params, defaultWorkdir) + handler, buildErr = 
buildUserBuiltinHookHandler(strings.TrimSpace(item.Handler), item.Params, defaultWorkdir) specKind = runtimehooks.HookKindFunction specMode = runtimehooks.HookModeSync case configuredHookKindCommand: - handler, err = buildUserCommandHookHandler(item.Params, defaultWorkdir) + handler, buildErr = buildUserCommandHookHandler( + strings.TrimSpace(item.ID), + point, + item.Params, + defaultWorkdir, + ) specKind = runtimehooks.HookKindCommand specMode = runtimehooks.HookModeSync case configuredHookKindHTTP: - handler, err = buildUserHTTPObserveHookHandler(item) + handler, buildErr = buildUserHTTPObserveHookHandler(item) specKind = runtimehooks.HookKindHTTP specMode = runtimehooks.HookModeObserve default: return runtimehooks.HookSpec{}, fmt.Errorf("kind %q is not supported", item.Kind) } - if err != nil { - return runtimehooks.HookSpec{}, err + if buildErr != nil { + return runtimehooks.HookSpec{}, buildErr } return runtimehooks.HookSpec{ ID: strings.TrimSpace(item.ID), - Point: runtimehooks.HookPoint(strings.TrimSpace(item.Point)), + Point: point, Scope: scope, Source: source, Kind: specKind, @@ -236,6 +243,7 @@ func buildConfiguredHookSpec( Timeout: time.Duration(item.TimeoutSec) * time.Second, FailurePolicy: mapRuntimeHookFailurePolicy(item.FailurePolicy), Handler: handler, + Matcher: matcher, }, nil } @@ -254,12 +262,26 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if mode != configuredHookModeSync { return fmt.Errorf("mode %q is not supported", item.Mode) } + handler := strings.ToLower(strings.TrimSpace(item.Handler)) + if handler == "warn_on_tool_call" && !runtimehooks.HasHookMatcherConfig(item.Match) { + return fmt.Errorf("handler %q requires match", item.Handler) + } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } case configuredHookKindCommand: if mode 
!= configuredHookModeSync { return fmt.Errorf("mode %q is not supported for kind command (only sync)", item.Mode) } - if strings.TrimSpace(readHookParamString(item.Params, "command")) == "" { - return fmt.Errorf("kind command requires params.command") + if _, _, err := runtimehooks.ParseCommandParams(item.Params); err != nil { + return err + } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } } case configuredHookKindHTTP: if mode != configuredHookModeObserve { @@ -269,6 +291,11 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop if policy == "fail_closed" { return fmt.Errorf("failure_policy %q is not supported for kind http observe", item.FailurePolicy) } + if runtimehooks.HasHookMatcherConfig(item.Match) { + if err := runtimehooks.ValidateHookMatcher(runtimehooks.HookPoint(strings.TrimSpace(item.Point)), item.Match); err != nil { + return fmt.Errorf("match: %w", err) + } + } default: if isExternalHookKind(kind) { return fmt.Errorf( @@ -281,6 +308,18 @@ func validateConfiguredHookItemForP6Lite(item config.RuntimeHookItemConfig, scop return nil } +// buildConfiguredHookMatcher compiles the hook matcher. +func buildConfiguredHookMatcher(item config.RuntimeHookItemConfig, point runtimehooks.HookPoint) (*runtimehooks.HookMatcher, error) { + if !runtimehooks.HasHookMatcherConfig(item.Match) { + return nil, nil + } + matcher, err := runtimehooks.CompileHookMatcher(point, item.Match) + if err != nil { + return nil, fmt.Errorf("match: %w", err) + } + return matcher, nil +} + func isExternalHookKind(kind string) bool { switch strings.ToLower(strings.TrimSpace(kind)) { case "command", "http", "prompt", "agent": @@ -337,28 +376,14 @@ func buildUserBuiltinHookHandler( return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} }, nil case "warn_on_tool_call": - targetTool := 
strings.ToLower(strings.TrimSpace(readHookParamString(params, "tool_name"))) - targetTools := normalizeHookParamStringSlice(readHookParamStringSlice(params, "tool_names")) - if targetTool == "" && len(targetTools) == 0 { - return nil, fmt.Errorf("handler warn_on_tool_call requires params.tool_name or params.tool_names") - } defaultMessage := "tool call matched warn_on_tool_call" if customMessage := strings.TrimSpace(readHookParamString(params, "message")); customMessage != "" { defaultMessage = customMessage } return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { _ = ctx - toolName := strings.ToLower(strings.TrimSpace(readHookContextMetadataString(input, "tool_name"))) - if toolName == "" { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} - } - if targetTool != "" && toolName == targetTool { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} - } - if len(targetTools) > 0 && slices.Contains(targetTools, toolName) { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} - } - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass} + _ = input + return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: defaultMessage} }, nil case "add_context_note": note := strings.TrimSpace(readHookParamString(params, "note")) @@ -378,49 +403,24 @@ func buildUserBuiltinHookHandler( } } -// buildUserCommandHookHandler turns a command hook into a synchronous blocking handler and passes the context JSON via stdin. -func buildUserCommandHookHandler(params map[string]any, defaultWorkdir string) (runtimehooks.HookHandler, error) { - command := strings.TrimSpace(readHookParamString(params, "command")) - if command == "" { - return nil, fmt.Errorf("kind command requires params.command") +// buildUserCommandHookHandler turns a command hook into a synchronous blocking handler that uses a stdin/stdout JSON protocol. +func buildUserCommandHookHandler(hookID string, point runtimehooks.HookPoint, params map[string]any, 
defaultWorkdir string) (runtimehooks.HookHandler, error) { + argv, shell, err := runtimehooks.ParseCommandParams(params) + if err != nil { + return nil, err } return func(ctx context.Context, input runtimehooks.HookContext) runtimehooks.HookResult { - workdir := resolveHookWorkdir(input, defaultWorkdir) - cmd := buildCommandHookProcess(ctx, command) - if strings.TrimSpace(workdir) != "" { - cmd.Dir = workdir - } - payload, err := json.Marshal(input) - if err != nil { - detail := fmt.Sprintf("command hook marshal input failed: %v", err) - return runtimehooks.HookResult{Status: runtimehooks.HookResultFailed, Message: detail, Error: detail} - } - cmd.Stdin = bytes.NewReader(payload) - output, err := cmd.CombinedOutput() - message := strings.TrimSpace(string(output)) - if err == nil { - return runtimehooks.HookResult{Status: runtimehooks.HookResultPass, Message: message} - } - var exitErr *exec.ExitError - if errors.As(err, &exitErr) && (exitErr.ExitCode() == 1 || exitErr.ExitCode() == 2) { - return runtimehooks.HookResult{Status: runtimehooks.HookResultBlock, Message: message} - } - detail := strings.TrimSpace(message) - if detail == "" { - detail = err.Error() - } - return runtimehooks.HookResult{Status: runtimehooks.HookResultFailed, Message: detail, Error: err.Error()} + spec := runtimehooks.CommandHookSpec{ + HookID: hookID, + Point: point, + Command: argv, + Shell: shell, + Workdir: resolveHookWorkdir(input, defaultWorkdir), + } + return runtimehooks.RunCommandHook(ctx, spec, input) }, nil } -// buildCommandHookProcess runs the user command through the current platform's shell, preserving the ability to compose scripts. -func buildCommandHookProcess(ctx context.Context, command string) *exec.Cmd { - if runtime.GOOS == "windows" { - return exec.CommandContext(ctx, "powershell", "-Command", command) - } - return exec.CommandContext(ctx, "sh", "-c", command) -} - // buildUserHTTPObserveHookHandler converts a kind=http observe config into an observation callback handler. func buildUserHTTPObserveHookHandler(item config.RuntimeHookItemConfig) (runtimehooks.HookHandler, error) 
{ endpoint := strings.TrimSpace(readHookParamString(item.Params, "url")) diff --git a/internal/runtime/user_hooks_test.go b/internal/runtime/user_hooks_test.go index f328687ec..dfe7f48cf 100644 --- a/internal/runtime/user_hooks_test.go +++ b/internal/runtime/user_hooks_test.go @@ -33,9 +33,11 @@ func TestBuildUserHookSpecMapsFailurePolicyAndScope(t *testing.T) { Priority: 99, TimeoutSec: 7, FailurePolicy: "warn_only", - Params: map[string]any{ + Match: map[string]any{ "tool_name": "bash", - "message": "tool call warning", + }, + Params: map[string]any{ + "message": "tool call warning", }, } @@ -351,16 +353,13 @@ func TestWarnOnToolCallAndAddContextNoteHandlers(t *testing.T) { t.Parallel() warnHandler, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{ - "tool_name": "bash", - "message": "bash was called", + "message": "bash was called", }, t.TempDir()) if err != nil { t.Fatalf("build warn handler: %v", err) } warnResult := warnHandler(context.Background(), runtimehooks.HookContext{ - Metadata: map[string]any{ - "tool_name": "bash", - }, + Metadata: map[string]any{}, }) if warnResult.Status != runtimehooks.HookResultPass { t.Fatalf("warn status = %q, want pass", warnResult.Status) @@ -369,13 +368,13 @@ func TestWarnOnToolCallAndAddContextNoteHandlers(t *testing.T) { t.Fatalf("warn message = %q, want %q", warnResult.Message, "bash was called") } - ignoreResult := warnHandler(context.Background(), runtimehooks.HookContext{ + anyToolResult := warnHandler(context.Background(), runtimehooks.HookContext{ Metadata: map[string]any{ "tool_name": "filesystem", }, }) - if strings.TrimSpace(ignoreResult.Message) != "" { - t.Fatalf("expected unmatched tool to have empty message, got %q", ignoreResult.Message) + if anyToolResult.Message != "bash was called" { + t.Fatalf("warn message = %q, want %q", anyToolResult.Message, "bash was called") } noteHandler, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{ @@ -408,7 +407,7 @@ func 
TestConfigureRuntimeHooksFromConfig(t *testing.T) { Kind: "builtin", Mode: "sync", Handler: "warn_on_tool_call", - Params: map[string]any{ + Match: map[string]any{ "tool_name": "bash", }, }, @@ -455,10 +454,12 @@ func TestConfigureRuntimeHooksFromConfigKeepsBaseExecutorAndComposes(t *testing. Scope: "user", Kind: "builtin", Mode: "sync", + Match: map[string]any{ + "tool_name": "bash", + }, Handler: "warn_on_tool_call", Params: map[string]any{ - "tool_name": "bash", - "message": "warn", + "message": "warn", }, }, } @@ -821,7 +822,7 @@ func TestConfigureRuntimeHooksWithoutItemsKeepsBehaviorUnchanged(t *testing.T) { out := service.hookExecutor.Run( context.Background(), runtimehooks.HookPointBeforeToolCall, - runtimehooks.HookContext{Metadata: map[string]any{"tool_name": "bash", "workdir": cfg.Workdir}}, + runtimehooks.HookContext{Metadata: map[string]any{"workdir": cfg.Workdir}}, ) if out.Blocked || len(out.Results) != 0 { t.Fatalf("unexpected hook output without user/repo config: %+v", out) @@ -834,8 +835,12 @@ func TestBuildUserBuiltinHookHandlerEdgeCases(t *testing.T) { if _, err := buildUserBuiltinHookHandler("require_file_exists", map[string]any{}, t.TempDir()); err == nil { t.Fatal("expected missing path error") } - if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, t.TempDir()); err == nil { - t.Fatal("expected missing target error") + handlerWithoutTarget, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, t.TempDir()) + if err != nil { + t.Fatalf("build warn_on_tool_call without target error: %v", err) + } + if got := handlerWithoutTarget(context.Background(), runtimehooks.HookContext{}); got.Message == "" { + t.Fatalf("expected default warning message when no target is configured, got %+v", got) } if _, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{}, t.TempDir()); err == nil { t.Fatal("expected missing note/message error") @@ -851,7 +856,7 @@ func 
TestBuildUserBuiltinHookHandlerEdgeCases(t *testing.T) { t.Fatalf("expected match message, got %q", pass.Message) } noTool := handler(context.Background(), runtimehooks.HookContext{}) - if noTool.Status != runtimehooks.HookResultPass || noTool.Message != "" { + if noTool.Status != runtimehooks.HookResultPass || noTool.Message != "hit" { t.Fatalf("unexpected no-tool result: %+v", noTool) } @@ -1518,8 +1523,8 @@ func TestUserHookHandlersAndPathChecks(t *testing.T) { if _, err := buildUserBuiltinHookHandler("require_file_exists", map[string]any{}, workdir); err == nil { t.Fatal("expected missing path error") } - if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, workdir); err == nil { - t.Fatal("expected missing tool target error") + if _, err := buildUserBuiltinHookHandler("warn_on_tool_call", map[string]any{}, workdir); err != nil { + t.Fatalf("warn_on_tool_call without target should be allowed for matcher-based filtering: %v", err) } if _, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{}, workdir); err == nil { t.Fatal("expected missing note/message error") @@ -1539,8 +1544,8 @@ func TestUserHookHandlersAndPathChecks(t *testing.T) { t.Fatalf("expected default warn message for matched tool") } result = warnHandler(context.Background(), runtimehooks.HookContext{}) - if result.Message != "" { - t.Fatalf("expected empty message when no tool_name metadata, got %q", result.Message) + if result.Message == "" { + t.Fatalf("expected default warn message, got empty") } noteHandler, err := buildUserBuiltinHookHandler("add_context_note", map[string]any{"message": "note-via-message"}, workdir) diff --git a/internal/session/input_preparer.go b/internal/session/input_preparer.go index 76b8dc6df..f5b9d9bab 100644 --- a/internal/session/input_preparer.go +++ b/internal/session/input_preparer.go @@ -20,6 +20,7 @@ const defaultSessionTitle = "New Session" // PrepareImageInput represents a local image reference attached to a single user input. type PrepareImageInput struct { Path 
string + AssetID string MimeType string } @@ -128,6 +129,32 @@ func (p *InputPreparer) Prepare(ctx context.Context, input PrepareInput) (Prepar savedAssets := make([]AssetMeta, 0, len(input.Images)) for index, image := range input.Images { path := strings.TrimSpace(image.Path) + assetID := strings.TrimSpace(image.AssetID) + if assetID != "" { + if path != "" { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, &AssetSaveError{ + SessionID: session.ID, + Index: index, + Path: path, + Err: fmt.Errorf("image input cannot contain both path and asset id"), + } + } + meta, err := p.referenceImageAsset(ctx, session.ID, assetID, image.MimeType) + if err != nil { + p.rollbackCreatedSession(ctx, session.ID, sessionCreated) + p.cleanupSavedAssets(ctx, session.ID, savedAssets) + return PreparedInput{}, &AssetSaveError{ + SessionID: session.ID, + Index: index, + Path: assetID, + Err: err, + } + } + parts = append(parts, providertypes.NewSessionAssetImagePart(meta.ID, meta.MimeType)) + continue + } if path == "" { p.rollbackCreatedSession(ctx, session.ID, sessionCreated) p.cleanupSavedAssets(ctx, session.ID, savedAssets) @@ -220,6 +247,38 @@ func (p *InputPreparer) saveImageAsset( return meta, nil } +// referenceImageAsset verifies that a saved asset belongs to the current session and returns image metadata that can be passed to the provider. +func (p *InputPreparer) referenceImageAsset( + ctx context.Context, + sessionID string, + assetID string, + mimeType string, +) (AssetMeta, error) { + if err := ctx.Err(); err != nil { + return AssetMeta{}, err + } + if p.assetStore == nil { + return AssetMeta{}, fmt.Errorf("session: asset store is not configured") + } + normalizedAssetID := strings.TrimSpace(assetID) + if normalizedAssetID == "" { + return AssetMeta{}, fmt.Errorf("image asset id is empty") + } + + meta, err := p.assetStore.Stat(ctx, sessionID, normalizedAssetID) + if err != nil { + return AssetMeta{}, fmt.Errorf("stat image asset: %w", err) + } + if 
!strings.HasPrefix(strings.ToLower(strings.TrimSpace(meta.MimeType)), "image/") { + return AssetMeta{}, fmt.Errorf("asset %q is not an image", normalizedAssetID) + } + declaredMime := normalizeMimeType(mimeType) + if declaredMime != "" && declaredMime != meta.MimeType { + return AssetMeta{}, fmt.Errorf("declared mime type %q mismatches saved asset %q", declaredMime, meta.MimeType) + } + return meta, nil +} + // resolveImageMimeType resolves the image MIME type, allows only image/*, and requires the declared value to match the type sniffed from the file header. func resolveImageMimeType(ctx context.Context, path string, declared string, file *os.File) (string, error) { if err := ctx.Err(); err != nil { diff --git a/internal/session/input_preparer_test.go b/internal/session/input_preparer_test.go index d45527799..356449cc7 100644 --- a/internal/session/input_preparer_test.go +++ b/internal/session/input_preparer_test.go @@ -1,6 +1,7 @@ package session import ( + "bytes" "context" "errors" "io" @@ -94,6 +95,46 @@ func TestInputPreparerPrepareTextAndImage(t *testing.T) { } } +func TestInputPreparerPrepareSavedAssetReference(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := newInputPreparerTestStore(t, workdir) + session := NewWithWorkdir("existing", workdir) + if err := createSessionForPreparerTest(context.Background(), store, session); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + meta, err := store.SaveAsset(context.Background(), session.ID, bytes.NewReader(minimalPNGBytes()), "image/png") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + result, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: session.ID, + Text: "describe it", + Images: []PrepareImageInput{{AssetID: meta.ID, MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 0 { + t.Fatalf("expected no newly saved assets, got %+v", result.SavedAssets) + } + if 
len(result.Parts) != 2 { + t.Fatalf("expected text and image parts, got %+v", result.Parts) + } + imagePart := result.Parts[1] + if imagePart.Kind != providertypes.ContentPartImage || + imagePart.Image == nil || + imagePart.Image.Asset == nil || + imagePart.Image.Asset.ID != meta.ID || + imagePart.Image.Asset.MimeType != "image/png" { + t.Fatalf("unexpected image part: %+v", imagePart) + } +} + func TestInputPreparerPrepareImageInfersMimeWhenMissing(t *testing.T) { t.Parallel() @@ -185,6 +226,51 @@ func TestInputPreparerPrepareErrors(t *testing.T) { } }) + t.Run("missing image reference is rejected", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "bad asset", + Images: []PrepareImageInput{{AssetID: " ", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected missing image reference error") + } + if !strings.Contains(err.Error(), "image path is empty") { + t.Fatalf("expected image reference error, got %v", err) + } + }) + + t.Run("missing referenced asset is rejected", func(t *testing.T) { + localStore := newInputPreparerTestStore(t, workdir) + existing := NewWithWorkdir("asset-missing", workdir) + if err := createSessionForPreparerTest(context.Background(), localStore, existing); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + preparer := NewInputPreparer(localStore, localStore) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + SessionID: existing.ID, + Text: "bad asset", + Images: []PrepareImageInput{{AssetID: "asset-missing", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected missing referenced asset error") + } + }) + + t.Run("asset id and path cannot both be set", func(t *testing.T) { + preparer := NewInputPreparer(store, store) + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "bad asset", + Images: 
[]PrepareImageInput{{Path: "a.png", AssetID: "asset-1", MimeType: "image/png"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected asset id and path conflict error") + } + }) + t.Run("asset save error is structured", func(t *testing.T) { preparer := NewInputPreparer(store, store) _, err := preparer.Prepare(context.Background(), PrepareInput{ @@ -384,6 +470,92 @@ func TestInputPreparerPrepareImagePathAndMimeValidation(t *testing.T) { t.Fatalf("expected mismatch error, got %v", err) } }) + + t.Run("declared mime params are normalized", func(t *testing.T) { + imagePath := filepath.Join(workdir, "declared-params.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + result, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "declared params", + Images: []PrepareImageInput{{Path: imagePath, MimeType: " IMAGE/PNG; charset=binary "}}, + DefaultWorkdir: workdir, + }) + if err != nil { + t.Fatalf("Prepare() error = %v", err) + } + if len(result.SavedAssets) != 1 || result.SavedAssets[0].MimeType != "image/png" { + t.Fatalf("unexpected saved assets: %+v", result.SavedAssets) + } + }) + + t.Run("declared non image mime is rejected", func(t *testing.T) { + imagePath := filepath.Join(workdir, "declared-text.png") + if err := os.WriteFile(imagePath, minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "declared text", + Images: []PrepareImageInput{{Path: imagePath, MimeType: "text/plain"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected non-image mime error") + } + if !strings.Contains(err.Error(), "is not an image") { + t.Fatalf("expected non-image mime error, got %v", err) + } + }) + + t.Run("extension mismatch is rejected when mime omitted", func(t *testing.T) { + imagePath := filepath.Join(workdir, "wrong.jpg") + if err := os.WriteFile(imagePath, 
minimalPNGBytes(), 0o644); err != nil { + t.Fatalf("write image: %v", err) + } + + _, err := preparer.Prepare(context.Background(), PrepareInput{ + Text: "extension mismatch", + Images: []PrepareImageInput{{Path: imagePath}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected extension mismatch error") + } + if !strings.Contains(err.Error(), "file extension mime") { + t.Fatalf("expected extension mismatch error, got %v", err) + } + }) +} + +func TestInputPreparerPrepareSavedAssetReferenceValidation(t *testing.T) { + t.Parallel() + + workdir := t.TempDir() + store := newInputPreparerTestStore(t, workdir) + session := NewWithWorkdir("existing", workdir) + if err := createSessionForPreparerTest(context.Background(), store, session); err != nil { + t.Fatalf("createSessionForPreparerTest() error = %v", err) + } + meta, err := store.SaveAsset(context.Background(), session.ID, bytes.NewReader(minimalPNGBytes()), "image/png") + if err != nil { + t.Fatalf("SaveAsset() error = %v", err) + } + + preparer := NewInputPreparer(store, store) + _, err = preparer.Prepare(context.Background(), PrepareInput{ + SessionID: session.ID, + Text: "bad declared mime", + Images: []PrepareImageInput{{AssetID: meta.ID, MimeType: "image/jpeg"}}, + DefaultWorkdir: workdir, + }) + if err == nil { + t.Fatalf("expected referenced asset mime mismatch") + } + if !strings.Contains(err.Error(), "mismatches saved asset") { + t.Fatalf("expected saved asset mismatch error, got %v", err) + } } func TestAssetSaveErrorMethods(t *testing.T) { diff --git a/internal/tools/ask_user_tool.go b/internal/tools/ask_user_tool.go index 96a20e8bc..d1a537cb1 100644 --- a/internal/tools/ask_user_tool.go +++ b/internal/tools/ask_user_tool.go @@ -130,10 +130,6 @@ func (t *askUserTool) Schema() map[string]any { } } -func (t *askUserTool) MicroCompactPolicy() MicroCompactPolicy { - return MicroCompactPolicyPreserveHistory -} - func (t *askUserTool) Execute(ctx context.Context, call ToolCallInput) 
(ToolResult, error) { if t.broker == nil { return NewErrorResult(ToolNameAskUser, "ask_user broker not available", "ask_user broker is nil", nil), fmt.Errorf("tools: ask_user broker is nil") diff --git a/internal/tools/ask_user_tool_test.go b/internal/tools/ask_user_tool_test.go index 6f003cc8b..427a28cca 100644 --- a/internal/tools/ask_user_tool_test.go +++ b/internal/tools/ask_user_tool_test.go @@ -37,9 +37,6 @@ func TestNewAskUserToolDefaults(t *testing.T) { if _, ok := schema["properties"]; !ok { t.Fatalf("expected schema with properties") } - if tool.MicroCompactPolicy() != MicroCompactPolicyPreserveHistory { - t.Fatalf("expected PreserveHistory policy, got %v", tool.MicroCompactPolicy()) - } } func TestAskUserToolSchemaHasRequiredFields(t *testing.T) { diff --git a/internal/tools/bash/tool.go b/internal/tools/bash/tool.go index 92cf5c0c9..0590587b7 100644 --- a/internal/tools/bash/tool.go +++ b/internal/tools/bash/tool.go @@ -4,10 +4,9 @@ import ( "context" "encoding/json" "errors" + "neo-code/internal/tools" "strings" "time" - - "neo-code/internal/tools" ) type Tool struct { @@ -80,11 +79,6 @@ func (t *Tool) Schema() map[string]any { } } -// MicroCompactPolicy declares that the bash tool's historical results take part in micro compact cleanup by default. -func (t *Tool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *Tool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var in input if err := json.Unmarshal(call.Arguments, &in); err != nil { diff --git a/internal/tools/codebase/read.go b/internal/tools/codebase/read.go index f76476e13..5e4ed970c 100644 --- a/internal/tools/codebase/read.go +++ b/internal/tools/codebase/read.go @@ -3,10 +3,10 @@ package codebase import ( "context" "encoding/json" + "neo-code/internal/tools" "strings" "neo-code/internal/repository" - "neo-code/internal/tools" ) // ReadTool implements the codebase_read tool. 
@@ -49,10 +49,6 @@ func (t *ReadTool) Schema() map[string]any { } } -func (t *ReadTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyPreserveHistory -} - func (t *ReadTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var in struct { Path string `json:"path"` diff --git a/internal/tools/codebase/read_test.go b/internal/tools/codebase/read_test.go index ba29f03b3..5fefb0ab6 100644 --- a/internal/tools/codebase/read_test.go +++ b/internal/tools/codebase/read_test.go @@ -32,9 +32,6 @@ func TestReadToolMetadata(t *testing.T) { if _, hasPath := props["path"]; !hasPath { t.Fatalf("Schema should have path property") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { - t.Fatalf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) - } } func TestReadToolInvalidJSON(t *testing.T) { diff --git a/internal/tools/codebase/searchsymbol.go b/internal/tools/codebase/searchsymbol.go index e5d9286eb..bd992eb9c 100644 --- a/internal/tools/codebase/searchsymbol.go +++ b/internal/tools/codebase/searchsymbol.go @@ -3,10 +3,10 @@ package codebase import ( "context" "encoding/json" + "neo-code/internal/tools" "strings" "neo-code/internal/repository" - "neo-code/internal/tools" ) // SearchSymbolTool implements the codebase_search_symbol tool. 
@@ -53,10 +53,6 @@ func (t *SearchSymbolTool) Schema() map[string]any { } } -func (t *SearchSymbolTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *SearchSymbolTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var in struct { Symbol string `json:"symbol"` diff --git a/internal/tools/codebase/searchsymbol_test.go b/internal/tools/codebase/searchsymbol_test.go index 33f792be3..cab6f50db 100644 --- a/internal/tools/codebase/searchsymbol_test.go +++ b/internal/tools/codebase/searchsymbol_test.go @@ -32,9 +32,6 @@ func TestSearchSymbolToolMetadata(t *testing.T) { if _, hasSymbol := props["symbol"]; !hasSymbol { t.Fatalf("Schema should have symbol property") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyCompact { - t.Fatalf("MicroCompactPolicy() = %v, want Compact", tool.MicroCompactPolicy()) - } } func TestSearchSymbolToolInvalidJSON(t *testing.T) { diff --git a/internal/tools/codebase/searchtext.go b/internal/tools/codebase/searchtext.go index f5f670c7d..f84b2fe6b 100644 --- a/internal/tools/codebase/searchtext.go +++ b/internal/tools/codebase/searchtext.go @@ -3,10 +3,10 @@ package codebase import ( "context" "encoding/json" + "neo-code/internal/tools" "strings" "neo-code/internal/repository" - "neo-code/internal/tools" ) // SearchTextTool implements the codebase_search_text tool. 
@@ -53,10 +53,6 @@ func (t *SearchTextTool) Schema() map[string]any { } } -func (t *SearchTextTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *SearchTextTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var in struct { Query string `json:"query"` diff --git a/internal/tools/codebase/searchtext_test.go b/internal/tools/codebase/searchtext_test.go index 06741a6b3..7cf87fe1a 100644 --- a/internal/tools/codebase/searchtext_test.go +++ b/internal/tools/codebase/searchtext_test.go @@ -33,9 +33,6 @@ func TestSearchTextToolMetadata(t *testing.T) { if _, hasQuery := props["query"]; !hasQuery { t.Fatalf("Schema should have query property") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyCompact { - t.Fatalf("MicroCompactPolicy() = %v, want Compact", tool.MicroCompactPolicy()) - } } func TestSearchTextToolInvalidJSON(t *testing.T) { diff --git a/internal/tools/diagnose/tool.go b/internal/tools/diagnose/tool.go index 4566d177f..379c542b6 100644 --- a/internal/tools/diagnose/tool.go +++ b/internal/tools/diagnose/tool.go @@ -5,13 +5,13 @@ import ( "encoding/json" "errors" "fmt" + "neo-code/internal/tools" "regexp" "strconv" "strings" "time" "neo-code/internal/subagent" - "neo-code/internal/tools" ) const ( @@ -83,11 +83,6 @@ func (t *Tool) Schema() map[string]any { } } -// MicroCompactPolicy 保留诊断结果,避免短期压缩时丢失排障上下文。 -func (t *Tool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyPreserveHistory -} - // Execute 校验输入并通过 SpawnSubAgent 能力链路执行真实诊断,失败时静默降级。 func (t *Tool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { if err := ctx.Err(); err != nil { diff --git a/internal/tools/diagnose/tool_test.go b/internal/tools/diagnose/tool_test.go index 4fa53ac47..8fb3e3986 100644 --- a/internal/tools/diagnose/tool_test.go +++ b/internal/tools/diagnose/tool_test.go @@ -20,9 +20,6 @@ func TestToolMetadata(t *testing.T) 
{ if tool.Schema() == nil { t.Fatal("Schema() should not be nil") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { - t.Fatalf("MicroCompactPolicy() = %q, want %q", tool.MicroCompactPolicy(), tools.MicroCompactPolicyPreserveHistory) - } } func TestToolExecuteFallbackWhenInvokerUnavailable(t *testing.T) { diff --git a/internal/tools/filesystem/copy_file.go b/internal/tools/filesystem/copy_file.go deleted file mode 100644 index 9d2b8c529..000000000 --- a/internal/tools/filesystem/copy_file.go +++ /dev/null @@ -1,129 +0,0 @@ -package filesystem - -import ( - "context" - "encoding/json" - "errors" - "os" - "path/filepath" - "strings" - - "neo-code/internal/tools" -) - -type CopyFileTool struct { - root string -} - -type copyFileInput struct { - SourcePath string `json:"source_path"` - DestinationPath string `json:"destination_path"` - Overwrite bool `json:"overwrite,omitempty"` -} - -func NewCopy(root string) *CopyFileTool { - return &CopyFileTool{root: root} -} - -func (t *CopyFileTool) Name() string { - return copyFileToolName -} - -func (t *CopyFileTool) Description() string { - return "Copy a file inside the workspace. Both paths must resolve inside the workspace." -} - -func (t *CopyFileTool) Schema() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "source_path": map[string]any{ - "type": "string", - "description": "Existing file path to copy, relative to workspace root or absolute inside workspace.", - }, - "destination_path": map[string]any{ - "type": "string", - "description": "Destination file path, relative to workspace root or absolute inside workspace.", - }, - "overwrite": map[string]any{ - "type": "boolean", - "description": "When true, replace destination if it already exists. 
Defaults to false.", - }, - }, - "required": []string{"source_path", "destination_path"}, - } -} - -func (t *CopyFileTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - -func (t *CopyFileTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { - var args copyFileInput - if err := json.Unmarshal(input.Arguments, &args); err != nil { - return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err - } - if strings.TrimSpace(args.SourcePath) == "" { - err := errors.New(copyFileToolName + ": source_path is required") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if strings.TrimSpace(args.DestinationPath) == "" { - err := errors.New(copyFileToolName + ": destination_path is required") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if err := ctx.Err(); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) - if err != nil { - return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err - } - - src, err := resolvePath(base, args.SourcePath) - if err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - dst, err := resolvePath(base, args.DestinationPath) - if err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - srcInfo, statErr := os.Stat(src) - if statErr != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr - } - if srcInfo.IsDir() { - err := errors.New(copyFileToolName + ": source_path must be a file, not a directory") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - if _, err := 
os.Stat(dst); err == nil { - if !args.Overwrite { - err := errors.New(copyFileToolName + ": destination_path already exists; pass overwrite=true to replace it") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - } else if !os.IsNotExist(err) { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if err := copyFileContents(src, dst, srcInfo.Mode().Perm()); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - return tools.ToolResult{ - Name: t.Name(), - Content: "ok", - Metadata: map[string]any{ - "source_path": normalizeSlashPath(toRelativePath(base, src)), - "destination_path": normalizeSlashPath(toRelativePath(base, dst)), - "paths": []string{normalizeSlashPath(toRelativePath(base, dst))}, - "bytes": srcInfo.Size(), - "overwrite": args.Overwrite, - }, - Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, - }, nil -} diff --git a/internal/tools/filesystem/copy_file_test.go b/internal/tools/filesystem/copy_file_test.go deleted file mode 100644 index 20abe2677..000000000 --- a/internal/tools/filesystem/copy_file_test.go +++ /dev/null @@ -1,210 +0,0 @@ -package filesystem - -import ( - "context" - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - - "neo-code/internal/tools" -) - -func TestCopyFileTool_DuplicatesContent(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - src := filepath.Join(workspace, "a.go") - if err := os.WriteFile(src, []byte("package main"), 0o644); err != nil { - t.Fatalf("seed: %v", err) - } - tool := NewCopy(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "a.go", - "destination_path": filepath.Join("nested", "b.go"), - }) - result, err := 
tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err != nil { - t.Fatalf("execute: %v", err) - } - if result.IsError { - t.Fatalf("error result: %s", result.Content) - } - if !result.Facts.WorkspaceWrite { - t.Fatalf("expected WorkspaceWrite=true") - } - srcData, _ := os.ReadFile(src) - if string(srcData) != "package main" { - t.Fatalf("source modified: %q", string(srcData)) - } - dst := filepath.Join(workspace, "nested", "b.go") - dstData, err := os.ReadFile(dst) - if err != nil { - t.Fatalf("read dst: %v", err) - } - if string(dstData) != "package main" { - t.Fatalf("dst content = %q", string(dstData)) - } - paths, ok := result.Metadata["paths"].([]string) - if !ok || len(paths) != 1 { - t.Fatalf("paths metadata = %#v want 1-item slice", result.Metadata["paths"]) - } - if got, _ := result.Metadata["source_path"].(string); got != "a.go" { - t.Fatalf("source_path metadata = %q want a.go", got) - } - if got, _ := result.Metadata["destination_path"].(string); got != "nested/b.go" { - t.Fatalf("destination_path metadata = %q want nested/b.go", got) - } - for _, value := range []string{result.Metadata["source_path"].(string), result.Metadata["destination_path"].(string), paths[0]} { - if filepath.IsAbs(value) || strings.Contains(strings.ToLower(value), strings.ToLower(workspace)) { - t.Fatalf("expected metadata path to stay workspace-relative, got %q", value) - } - } -} - -func TestCopyFileTool_RefusesOverwriteByDefault(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - src := filepath.Join(workspace, "src.txt") - dst := filepath.Join(workspace, "dst.txt") - if err := os.WriteFile(src, []byte("a"), 0o644); err != nil { - t.Fatalf("seed src: %v", err) - } - if err := os.WriteFile(dst, []byte("b"), 0o644); err != nil { - t.Fatalf("seed dst: %v", err) - } - tool := NewCopy(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "src.txt", - "destination_path": "dst.txt", 
- }) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "already exists") { - t.Fatalf("expected exists error, got %v", err) - } - if data, _ := os.ReadFile(dst); string(data) != "b" { - t.Fatalf("dst was clobbered: %q", string(data)) - } -} - -func TestCopyFileTool_OverwriteAllowed(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - src := filepath.Join(workspace, "src.txt") - dst := filepath.Join(workspace, "dst.txt") - if err := os.WriteFile(src, []byte("new"), 0o644); err != nil { - t.Fatalf("seed src: %v", err) - } - if err := os.WriteFile(dst, []byte("old"), 0o644); err != nil { - t.Fatalf("seed dst: %v", err) - } - tool := NewCopy(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "src.txt", - "destination_path": "dst.txt", - "overwrite": true, - }) - if _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }); err != nil { - t.Fatalf("execute: %v", err) - } - if data, _ := os.ReadFile(dst); string(data) != "new" { - t.Fatalf("dst content = %q want new", string(data)) - } - if data, _ := os.ReadFile(src); string(data) != "new" { - t.Fatalf("src removed unexpectedly: %q", string(data)) - } -} - -func TestCopyFileTool_RejectsTraversal(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - src := filepath.Join(workspace, "src.txt") - if err := os.WriteFile(src, []byte("x"), 0o644); err != nil { - t.Fatalf("seed: %v", err) - } - tool := NewCopy(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "src.txt", - "destination_path": filepath.Join("..", "escape.txt"), - }) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "escapes workspace") { - t.Fatalf("expected escape error, 
got %v", err) - } -} - -func TestCopyFileTool_InvalidJSON(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewCopy(workspace) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: []byte(`{invalid`), - Workdir: workspace, - }) - if err == nil { - t.Fatalf("expected json error") - } - if !result.IsError { - t.Fatalf("expected error result") - } -} - -func TestCopyFileTool_RejectsDirectorySource(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - sourceDir := filepath.Join(workspace, "srcdir") - if err := os.MkdirAll(sourceDir, 0o755); err != nil { - t.Fatalf("seed dir: %v", err) - } - tool := NewCopy(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "srcdir", - "destination_path": "copy.txt", - }) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "must be a file") { - t.Fatalf("expected directory source error, got %v", err) - } -} - -func TestCopyFileTool_RejectsCanceledContext(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewCopy(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "src.txt", - "destination_path": "dst.txt", - }) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := tool.Execute(ctx, tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) { - t.Fatalf("expected canceled error, got %v", err) - } -} diff --git a/internal/tools/filesystem/create_dir.go b/internal/tools/filesystem/create_dir.go deleted file mode 100644 index 22d0f8de8..000000000 --- a/internal/tools/filesystem/create_dir.go +++ /dev/null @@ -1,124 +0,0 @@ -package filesystem - -import ( - "context" - "encoding/json" - "errors" - "os" - "strings" - - "neo-code/internal/security" - 
"neo-code/internal/tools" -) - -type CreateDirTool struct { - root string -} - -type createDirInput struct { - Path string `json:"path"` - Recursive *bool `json:"recursive,omitempty"` -} - -func NewCreateDir(root string) *CreateDirTool { - return &CreateDirTool{root: root} -} - -func (t *CreateDirTool) Name() string { - return createDirToolName -} - -func (t *CreateDirTool) Description() string { - return "Create a directory inside the workspace. Recursive by default." -} - -func (t *CreateDirTool) Schema() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "Directory path relative to workspace root, or absolute inside the workspace.", - }, - "recursive": map[string]any{ - "type": "boolean", - "description": "When true (default), create parent directories as needed; when false, fail if the parent is missing.", - }, - }, - "required": []string{"path"}, - } -} - -func (t *CreateDirTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - -func (t *CreateDirTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { - var args createDirInput - if err := json.Unmarshal(input.Arguments, &args); err != nil { - return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err - } - if strings.TrimSpace(args.Path) == "" { - err := errors.New(createDirToolName + ": path is required") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if err := ctx.Err(); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - recursive := true - if args.Recursive != nil { - recursive = *args.Recursive - } - - base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) - if err != nil { - return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err - } - - _, target, 
err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.Path, resolvePath) - if err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - if info, statErr := os.Stat(target); statErr == nil { - if !info.IsDir() { - err := errors.New(createDirToolName + ": path exists and is not a directory") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - return tools.ToolResult{ - Name: t.Name(), - Content: "ok", - Metadata: map[string]any{ - "path": target, - "created": false, - "noop_write": true, - "recursive": recursive, - }, - Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, - }, nil - } else if !os.IsNotExist(statErr) { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr - } - - if recursive { - if err := os.MkdirAll(target, 0o755); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - } else { - if err := os.Mkdir(target, 0o755); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - } - - return tools.ToolResult{ - Name: t.Name(), - Content: "ok", - Metadata: map[string]any{ - "path": target, - "created": true, - "recursive": recursive, - }, - Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, - }, nil -} diff --git a/internal/tools/filesystem/create_dir_test.go b/internal/tools/filesystem/create_dir_test.go deleted file mode 100644 index a2434d98e..000000000 --- a/internal/tools/filesystem/create_dir_test.go +++ /dev/null @@ -1,136 +0,0 @@ -package filesystem - -import ( - "context" - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - - "neo-code/internal/tools" -) - -func TestCreateDirTool_RecursiveByDefault(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewCreateDir(workspace) - args, _ := 
json.Marshal(map[string]any{ - "path": filepath.Join("a", "b", "c"), - }) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err != nil { - t.Fatalf("execute: %v", err) - } - if result.IsError { - t.Fatalf("error result: %s", result.Content) - } - if !result.Facts.WorkspaceWrite { - t.Fatalf("expected WorkspaceWrite=true") - } - target := filepath.Join(workspace, "a", "b", "c") - if info, err := os.Stat(target); err != nil || !info.IsDir() { - t.Fatalf("dir not created: info=%v err=%v", info, err) - } - if got, _ := result.Metadata["created"].(bool); !got { - t.Fatalf("created metadata = %v want true", result.Metadata["created"]) - } -} - -func TestCreateDirTool_NonRecursiveFailsForMissingParent(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewCreateDir(workspace) - args, _ := json.Marshal(map[string]any{ - "path": filepath.Join("missing", "child"), - "recursive": false, - }) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil { - t.Fatalf("expected error for missing parent") - } -} - -func TestCreateDirTool_ExistingDirReturnsNoop(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - dir := filepath.Join(workspace, "existing") - if err := os.MkdirAll(dir, 0o755); err != nil { - t.Fatalf("seed: %v", err) - } - tool := NewCreateDir(workspace) - args, _ := json.Marshal(map[string]any{"path": "existing"}) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err != nil { - t.Fatalf("execute: %v", err) - } - if got, _ := result.Metadata["noop_write"].(bool); !got { - t.Fatalf("noop_write metadata = %v", result.Metadata["noop_write"]) - } - if got, _ := result.Metadata["created"].(bool); got { - t.Fatalf("created metadata = %v want false", 
result.Metadata["created"]) - } -} - -func TestCreateDirTool_RejectsExistingFile(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - target := filepath.Join(workspace, "blocker") - if err := os.WriteFile(target, []byte("file"), 0o644); err != nil { - t.Fatalf("seed: %v", err) - } - tool := NewCreateDir(workspace) - args, _ := json.Marshal(map[string]any{"path": "blocker"}) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "not a directory") { - t.Fatalf("expected file-blocking error, got %v", err) - } -} - -func TestCreateDirTool_RejectsTraversal(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewCreateDir(workspace) - args, _ := json.Marshal(map[string]any{"path": filepath.Join("..", "escape")}) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "escapes workspace") { - t.Fatalf("expected escape error, got %v", err) - } -} - -func TestCreateDirTool_InvalidJSON(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewCreateDir(workspace) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: []byte(`{invalid`), - Workdir: workspace, - }) - if err == nil { - t.Fatalf("expected json error") - } - if !result.IsError { - t.Fatalf("expected error result") - } -} diff --git a/internal/tools/filesystem/delete_file.go b/internal/tools/filesystem/delete_file.go index 32aba86d2..3923249fd 100644 --- a/internal/tools/filesystem/delete_file.go +++ b/internal/tools/filesystem/delete_file.go @@ -4,11 +4,11 @@ import ( "context" "encoding/json" "errors" + "neo-code/internal/tools" "os" "strings" "neo-code/internal/security" - "neo-code/internal/tools" ) type DeleteFileTool struct { @@ -44,10 +44,6 @@ func (t 
*DeleteFileTool) Schema() map[string]any { } } -func (t *DeleteFileTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *DeleteFileTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { var args deleteFileInput if err := json.Unmarshal(input.Arguments, &args); err != nil { @@ -88,7 +84,7 @@ func (t *DeleteFileTool) Execute(ctx context.Context, input tools.ToolCallInput) return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr } if info.IsDir() { - err := errors.New(deleteFileToolName + ": path is a directory; use filesystem_remove_dir") + err := errors.New(deleteFileToolName + ": path is a directory") return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err } diff --git a/internal/tools/filesystem/edit.go b/internal/tools/filesystem/edit.go index 2a07b6db2..0190b665c 100644 --- a/internal/tools/filesystem/edit.go +++ b/internal/tools/filesystem/edit.go @@ -5,11 +5,11 @@ import ( "encoding/json" "errors" "fmt" + "neo-code/internal/tools" "os" "strings" "neo-code/internal/security" - "neo-code/internal/tools" ) type EditTool struct { @@ -55,11 +55,6 @@ func (t *EditTool) Schema() map[string]any { } } -// MicroCompactPolicy 声明编辑工具的历史结果默认参与 micro compact 清理。 -func (t *EditTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *EditTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { var args editInput if err := json.Unmarshal(input.Arguments, &args); err != nil { diff --git a/internal/tools/filesystem/glob.go b/internal/tools/filesystem/glob.go index 3579ff8d0..534213633 100644 --- a/internal/tools/filesystem/glob.go +++ b/internal/tools/filesystem/glob.go @@ -4,13 +4,12 @@ import ( "context" "encoding/json" "errors" + "neo-code/internal/tools" "os" "path/filepath" "regexp" "sort" "strings" - - 
"neo-code/internal/tools" ) type GlobTool struct { @@ -61,11 +60,6 @@ func (t *GlobTool) Schema() map[string]any { } } -// MicroCompactPolicy 声明 glob 工具的历史结果默认参与 micro compact 清理。 -func (t *GlobTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *GlobTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { var args globInput if err := json.Unmarshal(input.Arguments, &args); err != nil { diff --git a/internal/tools/filesystem/grep.go b/internal/tools/filesystem/grep.go index 7bf1deca0..00c828c2f 100644 --- a/internal/tools/filesystem/grep.go +++ b/internal/tools/filesystem/grep.go @@ -5,12 +5,11 @@ import ( "encoding/json" "errors" "fmt" + "neo-code/internal/tools" "os" "path/filepath" "regexp" "strings" - - "neo-code/internal/tools" ) const defaultGrepResultLimit = 200 @@ -60,11 +59,6 @@ func (t *GrepTool) Schema() map[string]any { } } -// MicroCompactPolicy 声明 grep 工具的历史结果默认参与 micro compact 清理。 -func (t *GrepTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *GrepTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { var args grepInput if err := json.Unmarshal(input.Arguments, &args); err != nil { diff --git a/internal/tools/filesystem/helpers.go b/internal/tools/filesystem/helpers.go index c7b6366be..6b385716b 100644 --- a/internal/tools/filesystem/helpers.go +++ b/internal/tools/filesystem/helpers.go @@ -14,11 +14,7 @@ const ( grepToolName = tools.ToolNameFilesystemGrep globToolName = tools.ToolNameFilesystemGlob editToolName = tools.ToolNameFilesystemEdit - moveFileToolName = tools.ToolNameFilesystemMoveFile - copyFileToolName = tools.ToolNameFilesystemCopyFile deleteFileToolName = tools.ToolNameFilesystemDeleteFile - createDirToolName = tools.ToolNameFilesystemCreateDir - removeDirToolName = tools.ToolNameFilesystemRemoveDir ) func toRelativePath(root string, target string) string { 
diff --git a/internal/tools/filesystem/helpers_test.go b/internal/tools/filesystem/helpers_test.go index e3c0fdc29..dc6d9ada0 100644 --- a/internal/tools/filesystem/helpers_test.go +++ b/internal/tools/filesystem/helpers_test.go @@ -1,7 +1,6 @@ package filesystem import ( - "errors" "os" "path/filepath" "testing" @@ -62,30 +61,6 @@ func TestSkipDirEntry(t *testing.T) { } } -func TestIsCrossDeviceLinkError(t *testing.T) { - t.Parallel() - - cases := []struct { - name string - err error - want bool - }{ - {name: "nil", err: nil, want: false}, - {name: "other", err: errors.New("permission denied"), want: false}, - {name: "cross-device", err: errors.New("invalid cross-device link"), want: true}, - {name: "exdev", err: errors.New("rename failed: EXDEV"), want: true}, - } - - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - if got := isCrossDeviceLinkError(tc.err); got != tc.want { - t.Fatalf("isCrossDeviceLinkError(%v) = %v, want %v", tc.err, got, tc.want) - } - }) - } -} - func mustCreateDir(t *testing.T, path string) { t.Helper() if err := os.MkdirAll(path, 0o755); err != nil { diff --git a/internal/tools/filesystem/move_file.go b/internal/tools/filesystem/move_file.go deleted file mode 100644 index 8340d4159..000000000 --- a/internal/tools/filesystem/move_file.go +++ /dev/null @@ -1,166 +0,0 @@ -package filesystem - -import ( - "context" - "encoding/json" - "errors" - "io" - "os" - "path/filepath" - "strings" - - "neo-code/internal/tools" -) - -type MoveFileTool struct { - root string -} - -type moveFileInput struct { - SourcePath string `json:"source_path"` - DestinationPath string `json:"destination_path"` - Overwrite bool `json:"overwrite,omitempty"` -} - -func NewMove(root string) *MoveFileTool { - return &MoveFileTool{root: root} -} - -func (t *MoveFileTool) Name() string { - return moveFileToolName -} - -func (t *MoveFileTool) Description() string { - return "Move or rename a file inside the workspace. 
Both paths must resolve inside the workspace." -} - -func (t *MoveFileTool) Schema() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "source_path": map[string]any{ - "type": "string", - "description": "Existing file path to move, relative to workspace root or absolute inside workspace.", - }, - "destination_path": map[string]any{ - "type": "string", - "description": "New file path, relative to workspace root or absolute inside workspace.", - }, - "overwrite": map[string]any{ - "type": "boolean", - "description": "When true, replace destination if it already exists. Defaults to false.", - }, - }, - "required": []string{"source_path", "destination_path"}, - } -} - -func (t *MoveFileTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - -func (t *MoveFileTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { - var args moveFileInput - if err := json.Unmarshal(input.Arguments, &args); err != nil { - return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err - } - if strings.TrimSpace(args.SourcePath) == "" { - err := errors.New(moveFileToolName + ": source_path is required") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if strings.TrimSpace(args.DestinationPath) == "" { - err := errors.New(moveFileToolName + ": destination_path is required") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if err := ctx.Err(); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) - if err != nil { - return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err - } - - src, err := resolvePath(base, args.SourcePath) - if err != nil { - return tools.NewErrorResult(t.Name(), 
tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - dst, err := resolvePath(base, args.DestinationPath) - if err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - srcInfo, statErr := os.Stat(src) - if statErr != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr - } - if srcInfo.IsDir() { - err := errors.New(moveFileToolName + ": source_path must be a file, not a directory") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - if _, err := os.Stat(dst); err == nil { - if !args.Overwrite { - err := errors.New(moveFileToolName + ": destination_path already exists; pass overwrite=true to replace it") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - } else if !os.IsNotExist(err) { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - if err := os.MkdirAll(filepath.Dir(dst), 0o755); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if err := os.Rename(src, dst); err != nil { - if !isCrossDeviceLinkError(err) { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if copyErr := copyFileContents(src, dst, srcInfo.Mode().Perm()); copyErr != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), copyErr), "", nil), copyErr - } - if removeErr := os.Remove(src); removeErr != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), removeErr), "", nil), removeErr - } - } - - return tools.ToolResult{ - Name: t.Name(), - Content: "ok", - Metadata: map[string]any{ - "source_path": normalizeSlashPath(toRelativePath(base, src)), - "destination_path": normalizeSlashPath(toRelativePath(base, dst)), - "paths": []string{ - 
normalizeSlashPath(toRelativePath(base, src)), - normalizeSlashPath(toRelativePath(base, dst)), - }, - "bytes": srcInfo.Size(), - "overwrite": args.Overwrite, - }, - Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, - }, nil -} - -func copyFileContents(src, dst string, mode os.FileMode) error { - in, err := os.Open(src) - if err != nil { - return err - } - defer in.Close() - out, err := os.OpenFile(dst, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, mode) - if err != nil { - return err - } - defer out.Close() - if _, err := io.Copy(out, in); err != nil { - return err - } - return out.Sync() -} - -func isCrossDeviceLinkError(err error) bool { - if err == nil { - return false - } - msg := strings.ToLower(err.Error()) - return strings.Contains(msg, "cross-device") || strings.Contains(msg, "exdev") -} diff --git a/internal/tools/filesystem/move_file_test.go b/internal/tools/filesystem/move_file_test.go deleted file mode 100644 index 88965e1d8..000000000 --- a/internal/tools/filesystem/move_file_test.go +++ /dev/null @@ -1,263 +0,0 @@ -package filesystem - -import ( - "context" - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - - "neo-code/internal/tools" -) - -func TestMoveFileTool_RenamesWithinWorkspace(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - src := filepath.Join(workspace, "old.go") - if err := os.WriteFile(src, []byte("hello"), 0o644); err != nil { - t.Fatalf("seed: %v", err) - } - tool := NewMove(workspace) - - args, _ := json.Marshal(map[string]any{ - "source_path": "old.go", - "destination_path": "renamed.go", - }) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err != nil { - t.Fatalf("execute: %v", err) - } - if result.IsError { - t.Fatalf("unexpected error result: %s", result.Content) - } - if !result.Facts.WorkspaceWrite { - t.Fatalf("expected WorkspaceWrite=true") - } - if _, err := os.Stat(src); !os.IsNotExist(err) { - 
t.Fatalf("source still exists: err=%v", err) - } - dst := filepath.Join(workspace, "renamed.go") - if data, err := os.ReadFile(dst); err != nil { - t.Fatalf("read dst: %v", err) - } else if string(data) != "hello" { - t.Fatalf("dst content = %q want hello", string(data)) - } - if got, ok := result.Metadata["source_path"].(string); !ok || got != "old.go" { - t.Fatalf("source_path metadata = %v want old.go", got) - } - if got, ok := result.Metadata["destination_path"].(string); !ok || got != "renamed.go" { - t.Fatalf("destination_path metadata = %v want renamed.go", got) - } - paths, ok := result.Metadata["paths"].([]string) - if !ok || len(paths) != 2 { - t.Fatalf("paths metadata = %#v, want 2-item slice", result.Metadata["paths"]) - } - for _, value := range []string{result.Metadata["source_path"].(string), result.Metadata["destination_path"].(string), paths[0], paths[1]} { - if filepath.IsAbs(value) || strings.Contains(strings.ToLower(value), strings.ToLower(workspace)) { - t.Fatalf("expected metadata path to stay workspace-relative, got %q", value) - } - } -} - -func TestMoveFileTool_RejectsExistingDestinationWithoutOverwrite(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - src := filepath.Join(workspace, "src.txt") - dst := filepath.Join(workspace, "dst.txt") - if err := os.WriteFile(src, []byte("a"), 0o644); err != nil { - t.Fatalf("seed src: %v", err) - } - if err := os.WriteFile(dst, []byte("b"), 0o644); err != nil { - t.Fatalf("seed dst: %v", err) - } - tool := NewMove(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "src.txt", - "destination_path": "dst.txt", - }) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "already exists") { - t.Fatalf("expected exists error, got %v", err) - } - if data, _ := os.ReadFile(dst); string(data) != "b" { - t.Fatalf("dst content modified, got %q want b", 
string(data)) - } -} - -func TestMoveFileTool_OverwritesWhenAllowed(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - src := filepath.Join(workspace, "src.txt") - dst := filepath.Join(workspace, "dst.txt") - if err := os.WriteFile(src, []byte("new"), 0o644); err != nil { - t.Fatalf("seed src: %v", err) - } - if err := os.WriteFile(dst, []byte("old"), 0o644); err != nil { - t.Fatalf("seed dst: %v", err) - } - tool := NewMove(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "src.txt", - "destination_path": "dst.txt", - "overwrite": true, - }) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err != nil { - t.Fatalf("execute: %v", err) - } - if result.IsError { - t.Fatalf("error result: %s", result.Content) - } - if data, _ := os.ReadFile(dst); string(data) != "new" { - t.Fatalf("dst content = %q want new", string(data)) - } -} - -func TestMoveFileTool_RejectsTraversal(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - src := filepath.Join(workspace, "src.txt") - if err := os.WriteFile(src, []byte("x"), 0o644); err != nil { - t.Fatalf("seed: %v", err) - } - tool := NewMove(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "src.txt", - "destination_path": filepath.Join("..", "escape.txt"), - }) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "escapes workspace") { - t.Fatalf("expected escape error, got %v", err) - } -} - -func TestMoveFileTool_RejectsMissingSource(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewMove(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "missing.txt", - "destination_path": "out.txt", - }) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: 
workspace, - }) - if err == nil { - t.Fatalf("expected error for missing source") - } -} - -func TestMoveFileTool_RejectsEmptyPaths(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewMove(workspace) - - for _, tc := range []struct { - name string - args map[string]any - want string - }{ - { - name: "empty source", - args: map[string]any{"source_path": "", "destination_path": "x.txt"}, - want: "source_path is required", - }, - { - name: "empty destination", - args: map[string]any{"source_path": "x.txt", "destination_path": ""}, - want: "destination_path is required", - }, - } { - tc := tc - t.Run(tc.name, func(t *testing.T) { - args, _ := json.Marshal(tc.args) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), tc.want) { - t.Fatalf("expected %q, got %v", tc.want, err) - } - }) - } -} - -func TestMoveFileTool_InvalidJSON(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewMove(workspace) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: []byte(`{invalid`), - Workdir: workspace, - }) - if err == nil { - t.Fatalf("expected json error") - } - if !result.IsError { - t.Fatalf("expected error result") - } -} - -func TestMoveFileTool_RejectsDirectorySource(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - sourceDir := filepath.Join(workspace, "srcdir") - if err := os.MkdirAll(sourceDir, 0o755); err != nil { - t.Fatalf("seed dir: %v", err) - } - tool := NewMove(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "srcdir", - "destination_path": "moved.txt", - }) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "must be a file") { - t.Fatalf("expected directory source error, got %v", err) - 
} -} - -func TestMoveFileTool_RejectsCanceledContext(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewMove(workspace) - args, _ := json.Marshal(map[string]any{ - "source_path": "src.txt", - "destination_path": "dst.txt", - }) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := tool.Execute(ctx, tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), context.Canceled.Error()) { - t.Fatalf("expected canceled error, got %v", err) - } -} diff --git a/internal/tools/filesystem/read_file.go b/internal/tools/filesystem/read_file.go index c789fdd2b..51bc8fce8 100644 --- a/internal/tools/filesystem/read_file.go +++ b/internal/tools/filesystem/read_file.go @@ -5,12 +5,12 @@ import ( "encoding/json" "errors" "fmt" + "neo-code/internal/tools" "os" "path/filepath" "strings" "neo-code/internal/security" - "neo-code/internal/tools" ) const emitChunkSize = 4 * 1024 @@ -63,11 +63,6 @@ func (t *ReadFileTool) Schema() map[string]any { } } -// MicroCompactPolicy 声明读文件工具的历史结果默认参与 micro compact 清理。 -func (t *ReadFileTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *ReadFileTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { var args readFileInput if err := json.Unmarshal(input.Arguments, &args); err != nil { diff --git a/internal/tools/filesystem/remove_dir.go b/internal/tools/filesystem/remove_dir.go deleted file mode 100644 index ae0be3bb1..000000000 --- a/internal/tools/filesystem/remove_dir.go +++ /dev/null @@ -1,121 +0,0 @@ -package filesystem - -import ( - "context" - "encoding/json" - "errors" - "os" - "strings" - - "neo-code/internal/security" - "neo-code/internal/tools" -) - -type RemoveDirTool struct { - root string -} - -type removeDirInput struct { - Path string `json:"path"` - Force bool `json:"force,omitempty"` -} - -func NewRemoveDir(root 
string) *RemoveDirTool { - return &RemoveDirTool{root: root} -} - -func (t *RemoveDirTool) Name() string { - return removeDirToolName -} - -func (t *RemoveDirTool) Description() string { - return "Remove a directory inside the workspace. By default only empty directories are removed; pass force=true to remove recursively." -} - -func (t *RemoveDirTool) Schema() map[string]any { - return map[string]any{ - "type": "object", - "properties": map[string]any{ - "path": map[string]any{ - "type": "string", - "description": "Directory path relative to workspace root, or absolute inside the workspace.", - }, - "force": map[string]any{ - "type": "boolean", - "description": "When true, remove directory and all contents recursively. Defaults to false (empty directory only).", - }, - }, - "required": []string{"path"}, - } -} - -func (t *RemoveDirTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - -func (t *RemoveDirTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { - var args removeDirInput - if err := json.Unmarshal(input.Arguments, &args); err != nil { - return tools.NewErrorResult(t.Name(), "invalid arguments", err.Error(), nil), err - } - if strings.TrimSpace(args.Path) == "" { - err := errors.New(removeDirToolName + ": path is required") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - if err := ctx.Err(); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - base, err := tools.ResolveEffectiveRoot(t.root, input.Workdir) - if err != nil { - return tools.NewErrorResult(t.Name(), "invalid workdir", err.Error(), nil), err - } - - _, target, err := tools.ResolveWorkspaceTarget(input, security.TargetTypePath, base, args.Path, resolvePath) - if err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - info, statErr := 
os.Stat(target) - if statErr != nil { - if os.IsNotExist(statErr) { - return tools.ToolResult{ - Name: t.Name(), - Content: "ok", - Metadata: map[string]any{ - "path": target, - "removed": false, - "noop_write": true, - "force": args.Force, - }, - Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, - }, nil - } - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), statErr), "", nil), statErr - } - if !info.IsDir() { - err := errors.New(removeDirToolName + ": path is not a directory; use filesystem_delete_file") - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - - if args.Force { - if err := os.RemoveAll(target); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - } else { - if err := os.Remove(target); err != nil { - return tools.NewErrorResult(t.Name(), tools.NormalizeErrorReason(t.Name(), err), "", nil), err - } - } - - return tools.ToolResult{ - Name: t.Name(), - Content: "ok", - Metadata: map[string]any{ - "path": target, - "removed": true, - "force": args.Force, - }, - Facts: tools.ToolExecutionFacts{WorkspaceWrite: true}, - }, nil -} diff --git a/internal/tools/filesystem/remove_dir_test.go b/internal/tools/filesystem/remove_dir_test.go deleted file mode 100644 index ae9b69ab4..000000000 --- a/internal/tools/filesystem/remove_dir_test.go +++ /dev/null @@ -1,166 +0,0 @@ -package filesystem - -import ( - "context" - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" - - "neo-code/internal/tools" -) - -func TestRemoveDirTool_RemovesEmptyDir(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - dir := filepath.Join(workspace, "empty") - if err := os.MkdirAll(dir, 0o755); err != nil { - t.Fatalf("seed: %v", err) - } - tool := NewRemoveDir(workspace) - args, _ := json.Marshal(map[string]any{"path": "empty"}) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - 
Arguments: args, - Workdir: workspace, - }) - if err != nil { - t.Fatalf("execute: %v", err) - } - if result.IsError { - t.Fatalf("error result: %s", result.Content) - } - if !result.Facts.WorkspaceWrite { - t.Fatalf("expected WorkspaceWrite=true") - } - if _, err := os.Stat(dir); !os.IsNotExist(err) { - t.Fatalf("dir still exists: %v", err) - } -} - -func TestRemoveDirTool_RefusesNonEmptyWithoutForce(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - dir := filepath.Join(workspace, "full") - child := filepath.Join(dir, "x.txt") - if err := os.MkdirAll(dir, 0o755); err != nil { - t.Fatalf("seed dir: %v", err) - } - if err := os.WriteFile(child, []byte("x"), 0o644); err != nil { - t.Fatalf("seed file: %v", err) - } - tool := NewRemoveDir(workspace) - args, _ := json.Marshal(map[string]any{"path": "full"}) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil { - t.Fatalf("expected error for non-empty directory") - } - if _, err := os.Stat(child); err != nil { - t.Fatalf("child file destroyed: %v", err) - } -} - -func TestRemoveDirTool_ForceRemovesRecursive(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - dir := filepath.Join(workspace, "tree") - nested := filepath.Join(dir, "a", "b") - if err := os.MkdirAll(nested, 0o755); err != nil { - t.Fatalf("seed: %v", err) - } - if err := os.WriteFile(filepath.Join(nested, "c.txt"), []byte("c"), 0o644); err != nil { - t.Fatalf("seed file: %v", err) - } - tool := NewRemoveDir(workspace) - args, _ := json.Marshal(map[string]any{ - "path": "tree", - "force": true, - }) - if _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }); err != nil { - t.Fatalf("execute: %v", err) - } - if _, err := os.Stat(dir); !os.IsNotExist(err) { - t.Fatalf("dir still exists: %v", err) - } -} - -func TestRemoveDirTool_MissingDirReturnsNoop(t *testing.T) { - 
t.Parallel() - workspace := t.TempDir() - tool := NewRemoveDir(workspace) - args, _ := json.Marshal(map[string]any{"path": "phantom"}) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err != nil { - t.Fatalf("execute: %v", err) - } - if got, _ := result.Metadata["noop_write"].(bool); !got { - t.Fatalf("noop_write = %v", result.Metadata["noop_write"]) - } - if got, _ := result.Metadata["removed"].(bool); got { - t.Fatalf("removed = %v want false", result.Metadata["removed"]) - } -} - -func TestRemoveDirTool_RejectsFile(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - target := filepath.Join(workspace, "afile.txt") - if err := os.WriteFile(target, []byte("x"), 0o644); err != nil { - t.Fatalf("seed: %v", err) - } - tool := NewRemoveDir(workspace) - args, _ := json.Marshal(map[string]any{"path": "afile.txt"}) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "not a directory") { - t.Fatalf("expected directory-required error, got %v", err) - } -} - -func TestRemoveDirTool_RejectsTraversal(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewRemoveDir(workspace) - args, _ := json.Marshal(map[string]any{"path": filepath.Join("..", "escape")}) - _, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: args, - Workdir: workspace, - }) - if err == nil || !strings.Contains(err.Error(), "escapes workspace") { - t.Fatalf("expected escape error, got %v", err) - } -} - -func TestRemoveDirTool_InvalidJSON(t *testing.T) { - t.Parallel() - workspace := t.TempDir() - tool := NewRemoveDir(workspace) - result, err := tool.Execute(context.Background(), tools.ToolCallInput{ - Name: tool.Name(), - Arguments: []byte(`{invalid`), - Workdir: workspace, - }) - if err == nil { - 
t.Fatalf("expected json error") - } - if !result.IsError { - t.Fatalf("expected error result") - } -} diff --git a/internal/tools/filesystem/tool_metadata_test.go b/internal/tools/filesystem/tool_metadata_test.go index 27a6c8732..76f0bfe84 100644 --- a/internal/tools/filesystem/tool_metadata_test.go +++ b/internal/tools/filesystem/tool_metadata_test.go @@ -1,10 +1,7 @@ package filesystem import ( - "errors" "testing" - - "neo-code/internal/tools" ) func TestFilesystemToolMetadata(t *testing.T) { @@ -15,42 +12,12 @@ func TestFilesystemToolMetadata(t *testing.T) { toolName string description string schema map[string]any - policy tools.MicroCompactPolicy }{ - { - name: "copy", - toolName: NewCopy("/workspace").Name(), - description: NewCopy("/workspace").Description(), - schema: NewCopy("/workspace").Schema(), - policy: NewCopy("/workspace").MicroCompactPolicy(), - }, - { - name: "move", - toolName: NewMove("/workspace").Name(), - description: NewMove("/workspace").Description(), - schema: NewMove("/workspace").Schema(), - policy: NewMove("/workspace").MicroCompactPolicy(), - }, - { - name: "create dir", - toolName: NewCreateDir("/workspace").Name(), - description: NewCreateDir("/workspace").Description(), - schema: NewCreateDir("/workspace").Schema(), - policy: NewCreateDir("/workspace").MicroCompactPolicy(), - }, { name: "delete file", toolName: NewDelete("/workspace").Name(), description: NewDelete("/workspace").Description(), schema: NewDelete("/workspace").Schema(), - policy: NewDelete("/workspace").MicroCompactPolicy(), - }, - { - name: "remove dir", - toolName: NewRemoveDir("/workspace").Name(), - description: NewRemoveDir("/workspace").Description(), - schema: NewRemoveDir("/workspace").Schema(), - policy: NewRemoveDir("/workspace").MicroCompactPolicy(), }, } @@ -71,26 +38,6 @@ func TestFilesystemToolMetadata(t *testing.T) { if !ok || len(required) == 0 { t.Fatalf("required schema fields missing: %#v", tt.schema["required"]) } - if tt.policy != 
tools.MicroCompactPolicyCompact { - t.Fatalf("policy = %q, want compact", tt.policy) - } }) } } - -func TestMoveCrossDeviceHelper(t *testing.T) { - t.Parallel() - - if !isCrossDeviceLinkError(errors.New("rename failed: cross-device link")) { - t.Fatal("cross-device error should be detected") - } - if !isCrossDeviceLinkError(errors.New("EXDEV: invalid cross-device link")) { - t.Fatal("EXDEV error should be detected") - } - if isCrossDeviceLinkError(errors.New("permission denied")) { - t.Fatal("unrelated error should not be detected as cross-device") - } - if isCrossDeviceLinkError(nil) { - t.Fatal("nil error should not be detected as cross-device") - } -} diff --git a/internal/tools/filesystem/write_file.go b/internal/tools/filesystem/write_file.go index 1185fb259..a6bca9871 100644 --- a/internal/tools/filesystem/write_file.go +++ b/internal/tools/filesystem/write_file.go @@ -4,13 +4,13 @@ import ( "context" "encoding/json" "errors" + "neo-code/internal/tools" "os" "path/filepath" "strings" "unicode/utf8" "neo-code/internal/security" - "neo-code/internal/tools" ) type WriteFileTool struct { @@ -61,11 +61,6 @@ func (t *WriteFileTool) Schema() map[string]any { } } -// MicroCompactPolicy 声明写文件工具的历史结果默认参与 micro compact 清理。 -func (t *WriteFileTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *WriteFileTool) Execute(ctx context.Context, input tools.ToolCallInput) (tools.ToolResult, error) { var args writeFileInput if err := json.Unmarshal(input.Arguments, &args); err != nil { diff --git a/internal/tools/format_test.go b/internal/tools/format_test.go index d5433ae89..fb21c67ec 100644 --- a/internal/tools/format_test.go +++ b/internal/tools/format_test.go @@ -2,7 +2,6 @@ package tools import ( "errors" - "path/filepath" "strings" "testing" @@ -251,30 +250,6 @@ func TestSanitizeToolMetadata(t *testing.T) { } }, }, - { - name: "keeps relative copy and move path metadata but drops path arrays", - tool: 
"filesystem_move_file", - input: map[string]any{ - "source_path": "package.json", - "destination_path": "pkg.json", - "paths": []string{"/repo/package.json", "/repo/pkg.json"}, - }, - assert: func(t *testing.T, got map[string]string) { - t.Helper() - if got["source_path"] != "package.json" { - t.Fatalf("expected source_path to be preserved, got %#v", got) - } - if got["destination_path"] != "pkg.json" { - t.Fatalf("expected destination_path to be preserved, got %#v", got) - } - if filepath.IsAbs(got["source_path"]) || filepath.IsAbs(got["destination_path"]) { - t.Fatalf("expected projected copy/move paths to be relative, got %#v", got) - } - if got["paths"] != "" { - t.Fatalf("expected array metadata to be dropped, got %#v", got) - } - }, - }, } for _, tt := range tests { diff --git a/internal/tools/manager.go b/internal/tools/manager.go index 40764d90e..11ae852ec 100644 --- a/internal/tools/manager.go +++ b/internal/tools/manager.go @@ -27,8 +27,6 @@ type SpecListInput struct { // Manager is the runtime-facing tool execution and schema exposure boundary. 
type Manager interface { ListAvailableSpecs(ctx context.Context, input SpecListInput) ([]providertypes.ToolSpec, error) - MicroCompactPolicy(name string) MicroCompactPolicy - MicroCompactSummarizer(name string) ContentSummarizer // Execute 必须支持并发调用;runtime 可能在同一轮中并行调度多个工具调用。 Execute(ctx context.Context, input ToolCallInput) (ToolResult, error) RememberSessionDecision(sessionID string, action security.Action, scope SessionPermissionScope) error @@ -42,11 +40,9 @@ type Executor interface { } type microCompactPolicyExecutor interface { - MicroCompactPolicy(name string) MicroCompactPolicy } type microCompactSummarizerExecutor interface { - MicroCompactSummarizer(name string) ContentSummarizer } // factsEnrichingExecutor 包装底层执行器,在不信任外部 metadata 的前提下补齐受信结构化事实。 @@ -72,22 +68,6 @@ func (e *factsEnrichingExecutor) Supports(name string) bool { return e.inner.Supports(name) } -// MicroCompactPolicy 透传被包装执行器的压缩策略,确保 UI/Runtime 行为与原实现一致。 -func (e *factsEnrichingExecutor) MicroCompactPolicy(name string) MicroCompactPolicy { - if source, ok := e.inner.(microCompactPolicyExecutor); ok { - return source.MicroCompactPolicy(name) - } - return MicroCompactPolicyCompact -} - -// MicroCompactSummarizer 透传被包装执行器的摘要器实现,避免包装层吞掉摘要能力。 -func (e *factsEnrichingExecutor) MicroCompactSummarizer(name string) ContentSummarizer { - if source, ok := e.inner.(microCompactSummarizerExecutor); ok { - return source.MicroCompactSummarizer(name) - } - return nil -} - // Execute 在执行后按本地权限动作补齐可信 facts,避免运行时依赖远端 metadata。 func (e *factsEnrichingExecutor) Execute(ctx context.Context, input ToolCallInput) (ToolResult, error) { result, err := e.inner.Execute(ctx, input) @@ -324,28 +304,6 @@ func (m *DefaultManager) ListAvailableSpecs(ctx context.Context, input SpecListI return readOnlyFiltered, nil } -// MicroCompactPolicy 返回工具的 micro compact 策略;无法判断时按默认可压缩处理。 -func (m *DefaultManager) MicroCompactPolicy(name string) MicroCompactPolicy { - if m == nil || m.executor == nil { - return MicroCompactPolicyCompact - 
} - if source, ok := m.executor.(microCompactPolicyExecutor); ok { - return source.MicroCompactPolicy(name) - } - return MicroCompactPolicyCompact -} - -// MicroCompactSummarizer 返回工具的内容摘要器;未注册时返回 nil。 -func (m *DefaultManager) MicroCompactSummarizer(name string) ContentSummarizer { - if m == nil || m.executor == nil { - return nil - } - if source, ok := m.executor.(microCompactSummarizerExecutor); ok { - return source.MicroCompactSummarizer(name) - } - return nil -} - // Execute runs the tool if the permission engine allows it and the sandbox // check passes. func (m *DefaultManager) Execute(ctx context.Context, input ToolCallInput) (ToolResult, error) { diff --git a/internal/tools/manager_test.go b/internal/tools/manager_test.go index cc83dc9bc..b9850a5f4 100644 --- a/internal/tools/manager_test.go +++ b/internal/tools/manager_test.go @@ -11,7 +11,6 @@ import ( "testing" "time" - providertypes "neo-code/internal/provider/types" "neo-code/internal/security" "neo-code/internal/tools/mcp" ) @@ -20,7 +19,6 @@ type managerStubTool struct { name string content string err error - policy MicroCompactPolicy callCount int lastCall ToolCallInput } @@ -31,8 +29,6 @@ func (t *managerStubTool) Description() string { return "stub tool" } func (t *managerStubTool) Schema() map[string]any { return map[string]any{"type": "object"} } -func (t *managerStubTool) MicroCompactPolicy() MicroCompactPolicy { return t.policy } - func (t *managerStubTool) Execute(ctx context.Context, call ToolCallInput) (ToolResult, error) { t.callCount++ t.lastCall = call @@ -49,21 +45,6 @@ type stubSandbox struct { lastAction security.Action } -type executorWithoutOptionalCompactFeatures struct{} - -func (executorWithoutOptionalCompactFeatures) ListAvailableSpecs(ctx context.Context, input SpecListInput) ([]providertypes.ToolSpec, error) { - if err := ctx.Err(); err != nil { - return nil, err - } - return nil, nil -} - -func (executorWithoutOptionalCompactFeatures) Execute(ctx context.Context, call 
ToolCallInput) (ToolResult, error) { - return ToolResult{}, ctx.Err() -} - -func (executorWithoutOptionalCompactFeatures) Supports(name string) bool { return false } - func (s *stubSandbox) Check(ctx context.Context, action security.Action) (*security.WorkspaceExecutionPlan, error) { s.callCount++ s.lastAction = action @@ -135,93 +116,6 @@ func TestDefaultManagerListAvailableSpecsReadOnlyFiltersWriteTools(t *testing.T) } } -func TestDefaultManagerMicroCompactPolicy(t *testing.T) { - t.Parallel() - - t.Run("nil manager defaults to compact", func(t *testing.T) { - t.Parallel() - - var manager *DefaultManager - if got := manager.MicroCompactPolicy("custom_tool"); got != MicroCompactPolicyCompact { - t.Fatalf("expected compact default, got %q", got) - } - }) - - t.Run("executor without policy support defaults to compact", func(t *testing.T) { - t.Parallel() - - manager, err := NewManager(executorWithoutOptionalCompactFeatures{}, mustAllowEngine(t), nil) - if err != nil { - t.Fatalf("new manager: %v", err) - } - if got := manager.MicroCompactPolicy("custom_tool"); got != MicroCompactPolicyCompact { - t.Fatalf("expected compact default, got %q", got) - } - }) - - t.Run("executor policy is forwarded", func(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - registry.Register(&managerStubTool{name: "preserve_tool", policy: MicroCompactPolicyPreserveHistory}) - - manager, err := NewManager(registry, mustAllowEngine(t), nil) - if err != nil { - t.Fatalf("new manager: %v", err) - } - if got := manager.MicroCompactPolicy("preserve_tool"); got != MicroCompactPolicyPreserveHistory { - t.Fatalf("expected preserve history, got %q", got) - } - }) -} - -func TestDefaultManagerMicroCompactSummarizer(t *testing.T) { - t.Parallel() - - t.Run("nil manager returns nil", func(t *testing.T) { - t.Parallel() - - var manager *DefaultManager - if got := manager.MicroCompactSummarizer("custom_tool"); got != nil { - t.Fatalf("expected nil summarizer, got non-nil") - } - }) - - 
t.Run("executor without summarizer support returns nil", func(t *testing.T) { - t.Parallel() - - manager, err := NewManager(executorWithoutOptionalCompactFeatures{}, mustAllowEngine(t), nil) - if err != nil { - t.Fatalf("new manager: %v", err) - } - if got := manager.MicroCompactSummarizer("custom_tool"); got != nil { - t.Fatalf("expected nil summarizer, got non-nil") - } - }) - - t.Run("executor summarizer is forwarded", func(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - registry.RegisterSummarizer("custom_tool", func(content string, metadata map[string]string, isError bool) string { - return "summary:" + content - }) - - manager, err := NewManager(registry, mustAllowEngine(t), nil) - if err != nil { - t.Fatalf("new manager: %v", err) - } - - summarizer := manager.MicroCompactSummarizer("CUSTOM_TOOL") - if summarizer == nil { - t.Fatal("expected non-nil summarizer") - } - if got := summarizer("content", nil, false); got != "summary:content" { - t.Fatalf("unexpected summary output: %q", got) - } - }) -} - func TestDefaultManagerListAvailableSpecsBoundaries(t *testing.T) { t.Parallel() diff --git a/internal/tools/memo/list.go b/internal/tools/memo/list.go index 72b2676eb..5857df60f 100644 --- a/internal/tools/memo/list.go +++ b/internal/tools/memo/list.go @@ -4,10 +4,10 @@ import ( "context" "encoding/json" "fmt" + "neo-code/internal/tools" "strings" "neo-code/internal/memo" - "neo-code/internal/tools" ) const listToolName = tools.ToolNameMemoList @@ -44,11 +44,6 @@ func (t *ListTool) Schema() map[string]any { } } -// MicroCompactPolicy 记忆目录结果应保留在上下文中,不参与 micro compact 清理。 -func (t *ListTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyPreserveHistory -} - // Execute 执行 memo_list 工具调用。 func (t *ListTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { if t.svc == nil { diff --git a/internal/tools/memo/list_test.go b/internal/tools/memo/list_test.go index 7062039fd..60ab39c1f 
100644 --- a/internal/tools/memo/list_test.go +++ b/internal/tools/memo/list_test.go @@ -21,9 +21,6 @@ func TestListToolName(t *testing.T) { if tool.Schema() == nil { t.Fatal("Schema() should not be nil") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { - t.Fatalf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) - } } func TestListToolExecuteEmpty(t *testing.T) { diff --git a/internal/tools/memo/recall.go b/internal/tools/memo/recall.go index dfcc2cdff..16d84226e 100644 --- a/internal/tools/memo/recall.go +++ b/internal/tools/memo/recall.go @@ -4,11 +4,11 @@ import ( "context" "encoding/json" "fmt" + "neo-code/internal/tools" "sort" "strings" "neo-code/internal/memo" - "neo-code/internal/tools" ) const ( @@ -56,11 +56,6 @@ func (t *RecallTool) Schema() map[string]any { } } -// MicroCompactPolicy 记忆读取结果应保留在上下文中,不参与 micro compact 清理。 -func (t *RecallTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyPreserveHistory -} - // Execute 执行 memo_recall 工具调用。调用前须确保 svc 已通过构造函数注入。 func (t *RecallTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var args recallInput diff --git a/internal/tools/memo/recall_test.go b/internal/tools/memo/recall_test.go index 04cbf1c7e..93524e37c 100644 --- a/internal/tools/memo/recall_test.go +++ b/internal/tools/memo/recall_test.go @@ -87,9 +87,6 @@ func TestRecallToolDescriptionAndSchema(t *testing.T) { if schema == nil { t.Fatal("Schema() should not be nil") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { - t.Fatalf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) - } } func TestRecallToolExecuteWithScopeFilter(t *testing.T) { diff --git a/internal/tools/memo/remember.go b/internal/tools/memo/remember.go index 5009ef012..9de35fc61 100644 --- a/internal/tools/memo/remember.go +++ b/internal/tools/memo/remember.go @@ -4,10 +4,10 @@ import ( "context" 
"encoding/json" "fmt" + "neo-code/internal/tools" "strings" "neo-code/internal/memo" - "neo-code/internal/tools" ) const ( @@ -69,11 +69,6 @@ func (t *RememberTool) Schema() map[string]any { } } -// MicroCompactPolicy 记忆写入结果应保留在上下文中,不参与 micro compact 清理。 -func (t *RememberTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyPreserveHistory -} - // Execute 执行 memo_remember 工具调用。调用前须确保 svc 已通过构造函数注入。 func (t *RememberTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var args rememberInput diff --git a/internal/tools/memo/remove.go b/internal/tools/memo/remove.go index 6f5e0c811..0f11fc0bc 100644 --- a/internal/tools/memo/remove.go +++ b/internal/tools/memo/remove.go @@ -4,10 +4,10 @@ import ( "context" "encoding/json" "fmt" + "neo-code/internal/tools" "strings" "neo-code/internal/memo" - "neo-code/internal/tools" ) const removeToolName = tools.ToolNameMemoRemove @@ -50,11 +50,6 @@ func (t *RemoveTool) Schema() map[string]any { } } -// MicroCompactPolicy 删除结果应保留在上下文中,不参与 micro compact 清理。 -func (t *RemoveTool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyPreserveHistory -} - // Execute 执行 memo_remove 工具调用。 func (t *RemoveTool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { var args removeInput diff --git a/internal/tools/memo/remove_test.go b/internal/tools/memo/remove_test.go index bcf5d9a8a..4b96f3a39 100644 --- a/internal/tools/memo/remove_test.go +++ b/internal/tools/memo/remove_test.go @@ -21,9 +21,6 @@ func TestRemoveToolName(t *testing.T) { if tool.Schema() == nil { t.Fatal("Schema() should not be nil") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { - t.Fatalf("MicroCompactPolicy() = %v, want PreserveHistory", tool.MicroCompactPolicy()) - } } func TestRemoveToolExecuteSuccess(t *testing.T) { diff --git a/internal/tools/micro_compact_policy.go b/internal/tools/micro_compact_policy.go 
deleted file mode 100644 index 9225e9bcc..000000000 --- a/internal/tools/micro_compact_policy.go +++ /dev/null @@ -1,11 +0,0 @@ -package tools - -// MicroCompactPolicy describes how a tool's historical results participate in read-time micro compact. -type MicroCompactPolicy string - -const ( - // MicroCompactPolicyCompact means the tool's historical results take part in micro compact cleanup by default. - MicroCompactPolicyCompact MicroCompactPolicy = "" - // MicroCompactPolicyPreserveHistory means the tool's historical results must be kept explicitly and are excluded from micro compact cleanup. - MicroCompactPolicyPreserveHistory MicroCompactPolicy = "preserve_history" -) diff --git a/internal/tools/micro_compact_summarizer.go b/internal/tools/micro_compact_summarizer.go deleted file mode 100644 index ad5a8274e..000000000 --- a/internal/tools/micro_compact_summarizer.go +++ /dev/null @@ -1,6 +0,0 @@ -package tools - -// ContentSummarizer compresses tool result content into a short summary that micro-compact uses to replace old tool output. -// content and metadata come from the persisted Message fields; isError indicates whether the original tool call failed. -// An empty return value means "no summary; fall back to the default clearing behavior". -type ContentSummarizer func(content string, metadata map[string]string, isError bool) string diff --git a/internal/tools/micro_compact_summarizer_test.go b/internal/tools/micro_compact_summarizer_test.go deleted file mode 100644 index bf3cee474..000000000 --- a/internal/tools/micro_compact_summarizer_test.go +++ /dev/null @@ -1,460 +0,0 @@ -package tools - -import ( - "strings" - "sync" - "testing" - "unicode/utf8" -) - -// stubMetadata quickly builds a metadata map for tests. -func stubMetadata(keyValue ...string) map[string]string { - m := make(map[string]string, len(keyValue)/2) - for i := 0; i+1 < len(keyValue); i += 2 { - m[keyValue[i]] = keyValue[i+1] - } - return m -} - -func assertContains(t *testing.T, got, expected string) { - t.Helper() - if !strings.Contains(got, expected) { - t.Fatalf("expected %q in summary, got %q", expected, got) - } -} - -func assertMaxRuneCount(t *testing.T, got string, max int) { - t.Helper() - if utf8.RuneCountInString(got) > max { - t.Fatalf("summary exceeds %d runes: %d", max, utf8.RuneCountInString(got)) - } -} - -func assertEmptySummary(t 
*testing.T, got string) { - t.Helper() - if got != "" { - t.Fatalf("expected empty string, got %q", got) - } -} - -func TestBashSummarizer(t *testing.T) { - t.Parallel() - - t.Run("normal_output", func(t *testing.T) { - content := "line1\nline2\nline3\nline4\nline5\nline6\nline7\nline8" - meta := stubMetadata("workdir", "/home/user/project") - got := bashSummarizer(content, meta, false) - assertContains(t, got, "[exit=0]") - assertContains(t, got, "workdir=/home/user/project") - assertContains(t, got, "lines=8") - assertContains(t, got, "chars=") - assertMaxRuneCount(t, got, summaryMaxRunes) - }) - - t.Run("error_output", func(t *testing.T) { - content := "error: command not found" - meta := stubMetadata("workdir", "/tmp") - got := bashSummarizer(content, meta, true) - assertContains(t, got, "[exit=non-zero]") - }) - - t.Run("short_output", func(t *testing.T) { - content := "ok" - got := bashSummarizer(content, nil, false) - assertContains(t, got, "lines=1") - }) - - t.Run("empty_content", func(t *testing.T) { - got := bashSummarizer("", nil, false) - assertContains(t, got, "[exit=0]") - }) - - t.Run("sanitizes_workdir_metadata", func(t *testing.T) { - meta := stubMetadata("workdir", " \n\t/tmp/proj\x07 ") - got := bashSummarizer("ok", meta, false) - if strings.ContainsAny(got, "\n\t\a") { - t.Fatalf("expected sanitized workdir without control characters, got %q", got) - } - assertContains(t, got, "workdir=/tmp/proj") - }) -} - -func TestReadFileSummarizer(t *testing.T) { - t.Parallel() - - t.Run("normal_file", func(t *testing.T) { - content := "package main\n\nfunc main() {\n\tfmt.Println(\"hello\")\n}\n" - meta := stubMetadata("path", "/home/user/main.go") - got := readFileSummarizer(content, meta, false) - assertContains(t, got, "/home/user/main.go") - assertContains(t, got, "lines=5") - assertContains(t, got, "chars=") - assertMaxRuneCount(t, got, summaryMaxRunes) - }) - - t.Run("trailing_newline_not_counted_as_extra_line", func(t *testing.T) { - content := 
"a\nb\n" - meta := stubMetadata("path", "/tmp/a.txt") - got := readFileSummarizer(content, meta, false) - assertContains(t, got, "lines=2") - }) - - t.Run("empty_lines_are_counted", func(t *testing.T) { - content := "\n\n" - meta := stubMetadata("path", "/tmp/empty.txt") - got := readFileSummarizer(content, meta, false) - assertContains(t, got, "lines=2") - }) - - t.Run("missing_path", func(t *testing.T) { - got := readFileSummarizer("content", nil, false) - assertEmptySummary(t, got) - }) - - t.Run("sanitizes_path_metadata", func(t *testing.T) { - content := "line1\nline2" - meta := stubMetadata("path", " \n\t/tmp/a.go\x07 ") - got := readFileSummarizer(content, meta, false) - if strings.ContainsAny(got, "\n\t\a") { - t.Fatalf("expected sanitized path without control characters, got %q", got) - } - assertContains(t, got, "/tmp/a.go") - }) -} - -func TestWriteFileSummarizer(t *testing.T) { - t.Parallel() - - t.Run("normal", func(t *testing.T) { - meta := stubMetadata("path", "/home/user/test.go", "bytes", "1024") - got := writeFileSummarizer("", meta, false) - assertContains(t, got, "/home/user/test.go") - assertContains(t, got, "1024 bytes") - assertMaxRuneCount(t, got, summaryMaxRunes) - }) - - t.Run("missing_path", func(t *testing.T) { - got := writeFileSummarizer("", stubMetadata("bytes", "100"), false) - assertEmptySummary(t, got) - }) - - t.Run("sanitizes_path_metadata", func(t *testing.T) { - meta := stubMetadata("path", " \n\t/tmp/out.go\x07 ", "bytes", "4") - got := writeFileSummarizer("", meta, false) - if strings.ContainsAny(got, "\n\t\a") { - t.Fatalf("expected sanitized path without control characters, got %q", got) - } - assertContains(t, got, "/tmp/out.go") - }) -} - -func TestEditSummarizer(t *testing.T) { - t.Parallel() - - t.Run("with_relative_path", func(t *testing.T) { - meta := stubMetadata("relative_path", "src/main.go", "path", "/abs/src/main.go", "search_length", "50", "replacement_length", "60") - got := editSummarizer("", meta, false) - 
assertContains(t, got, "src/main.go") - assertContains(t, got, "search=50") - assertMaxRuneCount(t, got, summaryMaxRunes) - }) - - t.Run("fallback_to_abs_path", func(t *testing.T) { - meta := stubMetadata("path", "/abs/src/main.go", "search_length", "10", "replacement_length", "20") - got := editSummarizer("", meta, false) - assertContains(t, got, "/abs/src/main.go") - }) - - t.Run("missing_path", func(t *testing.T) { - got := editSummarizer("", stubMetadata("search_length", "10"), false) - assertEmptySummary(t, got) - }) - - t.Run("sanitizes_path_metadata", func(t *testing.T) { - meta := stubMetadata("relative_path", " \n\tsrc/main.go\x07 ", "search_length", "10", "replacement_length", "12") - got := editSummarizer("", meta, false) - if strings.ContainsAny(got, "\n\t\a") { - t.Fatalf("expected sanitized path without control characters, got %q", got) - } - assertContains(t, got, "src/main.go") - }) - - t.Run("long_path_is_truncated", func(t *testing.T) { - longPath := strings.Repeat("abcdef/", 80) + "main.go" - meta := stubMetadata("path", longPath, "search_length", "10", "replacement_length", "20") - got := editSummarizer("", meta, false) - assertMaxRuneCount(t, got, summaryMaxRunes+3) - }) -} - -func TestGrepSummarizer(t *testing.T) { - t.Parallel() - - t.Run("with_matches", func(t *testing.T) { - content := "src/a.go:10:match1\nsrc/b.go:20:match2\nsrc/c.go:30:match3\nsrc/d.go:40:match4" - meta := stubMetadata("root", "/home/user", "matched_files", "4", "matched_lines", "4") - got := grepSummarizer(content, meta, false) - assertContains(t, got, "root=/home/user") - assertContains(t, got, "files=4") - assertMaxRuneCount(t, got, summaryMaxRunes) - }) - - t.Run("empty_content", func(t *testing.T) { - meta := stubMetadata("root", "/home", "matched_files", "0", "matched_lines", "0") - got := grepSummarizer("", meta, false) - assertContains(t, got, "files=0") - }) - - t.Run("sanitizes_root_metadata", func(t *testing.T) { - content := "a.go:1:x" - meta := 
stubMetadata("root", " \n\t/tmp/root\x07 ", "matched_files", "1", "matched_lines", "1") - got := grepSummarizer(content, meta, false) - if strings.ContainsAny(got, "\n\t\a") { - t.Fatalf("expected sanitized root without control characters, got %q", got) - } - assertContains(t, got, "root=/tmp/root") - }) - - t.Run("sanitizes_injected_filename", func(t *testing.T) { - content := "src/a.go\nignore:1:x\nsafe.go:2:y" - meta := stubMetadata("matched_files", "2", "matched_lines", "2") - got := grepSummarizer(content, meta, false) - if strings.Contains(got, "\n") || strings.Contains(got, "\t") { - t.Fatalf("expected sanitized summary without control characters, got %q", got) - } - assertContains(t, got, "matches=ignore, safe.go") - }) -} - -func TestGlobSummarizer(t *testing.T) { - t.Parallel() - - t.Run("with_files", func(t *testing.T) { - content := "src/a.go\nsrc/b.go\nsrc/c.go\nsrc/d.go" - meta := stubMetadata("count", "4") - got := globSummarizer(content, meta, false) - assertContains(t, got, "4 files") - assertMaxRuneCount(t, got, summaryMaxRunes) - }) - - t.Run("no_matches", func(t *testing.T) { - meta := stubMetadata("count", "0") - got := globSummarizer("", meta, false) - assertContains(t, got, "0 files") - }) - - t.Run("skips_blank_and_control_lines", func(t *testing.T) { - content := "\n\t\nsrc/a.go\nsrc/b.go\n" - meta := stubMetadata("count", "2") - got := globSummarizer(content, meta, false) - assertContains(t, got, "src/a.go, src/b.go") - if strings.Contains(got, "\n") || strings.Contains(got, "\t") { - t.Fatalf("expected sanitized preview, got %q", got) - } - }) -} - -func TestWebfetchSummarizer(t *testing.T) { - t.Parallel() - - t.Run("with_truncated_flag", func(t *testing.T) { - meta := stubMetadata("truncated", "true") - got := webfetchSummarizer("", meta, false) - assertContains(t, got, "truncated=true") - }) - - t.Run("minimal", func(t *testing.T) { - got := webfetchSummarizer("", nil, false) - assertContains(t, got, "[summary] webfetch") - }) -} - 
-func TestRegisterBuiltinSummarizers(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - RegisterBuiltinSummarizers(registry) - - toolNames := []string{ - ToolNameBash, ToolNameFilesystemReadFile, ToolNameCodebaseRead, ToolNameFilesystemWriteFile, - ToolNameFilesystemEdit, ToolNameFilesystemGrep, ToolNameFilesystemGlob, - ToolNameWebFetch, - } - for _, name := range toolNames { - if registry.MicroCompactSummarizer(name) == nil { - t.Errorf("expected summarizer for %q to be registered", name) - } - } - - // Tools that are not in the registration list should return nil - if registry.MicroCompactSummarizer("unknown_tool") != nil { - t.Fatal("expected nil for unknown tool") - } -} - -func TestRegisterBuiltinSummarizersNilRegistry(t *testing.T) { - t.Parallel() - RegisterBuiltinSummarizers(nil) -} - -func TestRegisterSummarizer(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - - // Register - called := false - registry.RegisterSummarizer("test_tool", func(content string, metadata map[string]string, isError bool) string { - called = true - return "summary" - }) - - s := registry.MicroCompactSummarizer("test_tool") - if s == nil { - t.Fatal("expected summarizer to be registered") - } - result := s("content", nil, false) - if !called { - t.Fatal("expected summarizer to be called") - } - if result != "summary" { - t.Fatalf("expected 'summary', got %q", result) - } - - // Remove - registry.RegisterSummarizer("test_tool", nil) - if registry.MicroCompactSummarizer("test_tool") != nil { - t.Fatal("expected nil after removal") - } -} - -func TestRegisterSummarizerNormalizesName(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - registry.RegisterSummarizer(" Mixed_Tool ", func(content string, metadata map[string]string, isError bool) string { - return "ok" - }) - - if registry.MicroCompactSummarizer("mixed_tool") == nil { - t.Fatal("expected normalized summarizer lookup") - } - if registry.MicroCompactSummarizer(" MIXED_TOOL ") == nil { - t.Fatal("expected case-insensitive summarizer lookup") - } 
-} - -func TestRegisterSummarizerNilRegistry(t *testing.T) { - t.Parallel() - - var nilRegistry *Registry - nilRegistry.RegisterSummarizer("tool", func(content string, metadata map[string]string, isError bool) string { - return "ok" - }) - if nilRegistry.MicroCompactSummarizer("tool") != nil { - t.Fatal("expected nil summarizer on nil registry") - } -} - -func TestRegisterSummarizerConcurrentAccess(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - var wg sync.WaitGroup - - for i := 0; i < 8; i++ { - wg.Add(1) - go func() { - defer wg.Done() - for j := 0; j < 200; j++ { - if j%3 == 0 { - registry.RegisterSummarizer("concurrent_tool", nil) - continue - } - registry.RegisterSummarizer("concurrent_tool", func(content string, metadata map[string]string, isError bool) string { - return "worker" - }) - s := registry.MicroCompactSummarizer("concurrent_tool") - if s != nil { - _ = s("content", nil, false) - } - } - }() - } - - wg.Wait() -} - -func TestTruncateRunes(t *testing.T) { - t.Parallel() - - t.Run("short", func(t *testing.T) { - got := truncateRunes("hello", 10) - if got != "hello" { - t.Fatalf("expected unchanged, got %q", got) - } - }) - - t.Run("exact", func(t *testing.T) { - got := truncateRunes("hello", 5) - if got != "hello" { - t.Fatalf("expected unchanged, got %q", got) - } - }) - - t.Run("truncated", func(t *testing.T) { - got := truncateRunes("hello world", 5) - if got != "hello..." { - t.Fatalf("expected 'hello...', got %q", got) - } - }) - - t.Run("chinese", func(t *testing.T) { - got := truncateRunes("你好世界测试", 3) - if got != "你好世..." 
{ - t.Fatalf("expected '你好世...', got %q", got) - } - }) - - t.Run("zero_limit_keeps_original", func(t *testing.T) { - got := truncateRunes("hello", 0) - if got != "hello" { - t.Fatalf("expected unchanged with zero limit, got %q", got) - } - }) - - t.Run("empty_text", func(t *testing.T) { - got := truncateRunes("", 10) - if got != "" { - t.Fatalf("expected empty string, got %q", got) - } - }) -} - -func TestStableLineCount(t *testing.T) { - t.Parallel() - - t.Run("empty", func(t *testing.T) { - if got := stableLineCount(""); got != 0 { - t.Fatalf("expected 0, got %d", got) - } - }) - - t.Run("non_empty", func(t *testing.T) { - if got := stableLineCount("a\nb"); got != 2 { - t.Fatalf("expected 2, got %d", got) - } - }) - - t.Run("trailing_newline", func(t *testing.T) { - if got := stableLineCount("a\nb\n"); got != 2 { - t.Fatalf("expected 2, got %d", got) - } - }) - - t.Run("only_empty_lines", func(t *testing.T) { - if got := stableLineCount("\n\n"); got != 2 { - t.Fatalf("expected 2, got %d", got) - } - }) -} diff --git a/internal/tools/micro_compact_summarizers_builtin.go b/internal/tools/micro_compact_summarizers_builtin.go deleted file mode 100644 index de8c4886f..000000000 --- a/internal/tools/micro_compact_summarizers_builtin.go +++ /dev/null @@ -1,282 +0,0 @@ -package tools - -import ( - "strconv" - "strings" - "unicode/utf8" -) - -type builtinSummarizerRegistration struct { - toolName string - summarizer ContentSummarizer -} - -var builtinSummarizers = []builtinSummarizerRegistration{ - {toolName: ToolNameBash, summarizer: bashSummarizer}, - {toolName: ToolNameFilesystemReadFile, summarizer: readFileSummarizer}, - {toolName: ToolNameCodebaseRead, summarizer: readFileSummarizer}, - {toolName: ToolNameFilesystemWriteFile, summarizer: writeFileSummarizer}, - {toolName: ToolNameFilesystemEdit, summarizer: editSummarizer}, - {toolName: ToolNameFilesystemGrep, summarizer: grepSummarizer}, - {toolName: ToolNameFilesystemGlob, summarizer: globSummarizer}, - 
{toolName: ToolNameWebFetch, summarizer: webfetchSummarizer}, -} - -// RegisterBuiltinSummarizers registers the content summarizers of all built-in tools with the Registry. -// Intended to be called during startup wiring; repeated calls overwrite summarizers of the same name. -func RegisterBuiltinSummarizers(registry *Registry) { - if registry == nil { - return - } - for _, item := range builtinSummarizers { - registry.RegisterSummarizer(item.toolName, item.summarizer) - } -} - -const summaryMaxRunes = 200 -const metadataTokenMaxRunes = 120 - -// bashSummarizer keeps only structured execution metadata so that raw output is not re-injected into the context. -func bashSummarizer(content string, metadata map[string]string, isError bool) string { - var parts []string - - if isError { - parts = append(parts, "[exit=non-zero]") - } else { - parts = append(parts, "[exit=0]") - } - - if workdir := metadataToken(metadata["workdir"]); workdir != "" { - parts = append(parts, "workdir="+workdir) - } - - trimmed := strings.TrimSpace(content) - if trimmed != "" { - parts = appendTextStats(parts, trimmed) - } - - return truncateRunes(strings.Join(parts, " "), summaryMaxRunes) -} - -// readFileSummarizer keeps only stable metadata so that the file body is not exposed again in the summary. -func readFileSummarizer(content string, metadata map[string]string, isError bool) string { - path := metadataToken(metadata["path"]) - if path == "" { - return "" - } - - lineCount := stableLineCount(content) - - var parts []string - parts = append(parts, "[summary]", path, "lines="+strconv.Itoa(lineCount)) - if content != "" { - parts = append(parts, "chars="+strconv.Itoa(utf8.RuneCountInString(content))) - } - - return truncateRunes(strings.Join(parts, " "), summaryMaxRunes) -} - -// writeFileSummarizer keeps the file path and the number of bytes written. -func writeFileSummarizer(content string, metadata map[string]string, isError bool) string { - path := metadataToken(metadata["path"]) - if path == "" { - return "" - } - bytes := metadata["bytes"] - return truncateRunes("[summary] wrote "+path+" ("+bytes+" bytes)", summaryMaxRunes) -} - -// editSummarizer keeps the edited path and the replacement extent. -func editSummarizer(content string, metadata map[string]string, isError bool) string { - path := 
metadataToken(metadata["relative_path"]) - if path == "" { - path = metadataToken(metadata["path"]) - } - if path == "" { - return "" - } - searchLen := metadata["search_length"] - replaceLen := metadata["replacement_length"] - return truncateRunes( - "[summary] edited "+path+" (search="+searchLen+" chars, replace="+replaceLen+" chars)", - summaryMaxRunes, - ) -} - -// grepSummarizer keeps the search root, the match counts, and the first few file names. -func grepSummarizer(content string, metadata map[string]string, isError bool) string { - var parts []string - parts = append(parts, "[summary] grep") - - if root := metadataToken(metadata["root"]); root != "" { - parts = append(parts, "root="+root) - } - - if matchedFiles := metadata["matched_files"]; matchedFiles != "" { - parts = append(parts, "files="+matchedFiles) - } - if matchedLines := metadata["matched_lines"]; matchedLines != "" { - parts = append(parts, "lines="+matchedLines) - } - - // Extract the first few unique file names from content, avoiding a full split of the whole output. - fileNames := extractUniqueMatchFiles(content, 3) - if len(fileNames) > 0 { - parts = append(parts, "matches="+strings.Join(fileNames, ", ")) - } - - return truncateRunes(strings.Join(parts, " "), summaryMaxRunes) -} - -// globSummarizer keeps the match count and the first few file names. -func globSummarizer(content string, metadata map[string]string, isError bool) string { - count := metadata["count"] - if count == "" { - count = "?" 
- } - - preview := collectPreviewLines(content, 3) - - var parts []string - parts = append(parts, "[summary] glob", count+" files") - if len(preview) > 0 { - parts = append(parts, strings.Join(preview, ", ")) - } - - return truncateRunes(strings.Join(parts, " "), summaryMaxRunes) -} - -// webfetchSummarizer keeps the webfetch result markers that can be persisted stably. -func webfetchSummarizer(content string, metadata map[string]string, isError bool) string { - var parts []string - parts = append(parts, "[summary] webfetch") - - if truncated := metadata["truncated"]; truncated == "true" { - parts = append(parts, "truncated=true") - } - - return truncateRunes(strings.Join(parts, " "), summaryMaxRunes) -} - -// truncateRunes truncates a string by rune count, appending "..." when the limit is exceeded. -func truncateRunes(text string, maxRunes int) string { - if maxRunes <= 0 || text == "" { - return text - } - if utf8.RuneCountInString(text) <= maxRunes { - return text - } - runes := []rune(text) - return string(runes[:maxRunes]) + "..." -} - -// stableLineCount counts the lines of text; empty text returns 0, and a trailing newline does not add an extra empty line. -func stableLineCount(text string) int { - if text == "" { - return 0 - } - count := strings.Count(text, "\n") + 1 - if strings.HasSuffix(text, "\n") { - count-- - } - if count < 0 { - return 0 - } - return count -} - -// appendTextStats adds text statistics fields to the summary, keeping a uniform structured output format. -func appendTextStats(parts []string, text string) []string { - return append(parts, - "lines="+strconv.Itoa(stableLineCount(text)), - "chars="+strconv.Itoa(utf8.RuneCountInString(text)), - ) -} - -// extractUniqueMatchFiles scans grep output line by line and extracts the first few deduplicated file names. -func extractUniqueMatchFiles(content string, limit int) []string { - if limit <= 0 { - return nil - } - - seen := make(map[string]struct{}, limit) - result := make([]string, 0, limit) - remaining := content - for len(remaining) > 0 && len(result) < limit { - line, rest := nextLine(remaining) - remaining = rest - - colon := strings.Index(line, ":") - if colon <= 0 { - continue - } - - file := sanitizeSummaryToken(line[:colon], 80) - if file == "" { - 
continue - } - if _, ok := seen[file]; ok { - continue - } - seen[file] = struct{}{} - result = append(result, file) - } - return result -} - -// collectPreviewLines scans the output line by line and extracts the first few non-empty preview lines, avoiding the extra allocations of a full Split. -func collectPreviewLines(content string, limit int) []string { - if limit <= 0 { - return nil - } - - result := make([]string, 0, limit) - remaining := content - for len(remaining) > 0 && len(result) < limit { - line, rest := nextLine(remaining) - remaining = rest - - clean := sanitizeSummaryToken(line, 100) - if clean == "" { - continue - } - result = append(result, clean) - } - return result -} - -// nextLine returns the first line of text and the remaining text, handling input with or without a newline. -func nextLine(text string) (line string, rest string) { - idx := strings.IndexByte(text, '\n') - if idx < 0 { - return text, "" - } - return text[:idx], text[idx+1:] -} - -// sanitizeSummaryToken strips invisible control characters and trims the length to reduce the risk of summary injection. -func sanitizeSummaryToken(text string, maxRunes int) string { - trimmed := strings.TrimSpace(text) - if trimmed == "" { - return "" - } - - var b strings.Builder - b.Grow(len(trimmed)) - for _, r := range trimmed { - if r < 32 || r == 127 { - continue - } - b.WriteRune(r) - } - clean := strings.TrimSpace(b.String()) - if clean == "" { - return "" - } - return truncateRunes(clean, maxRunes) -} - -// metadataToken uniformly sanitizes metadata text fields that may be fed back into the summary. -func metadataToken(text string) string { - return sanitizeSummaryToken(text, metadataTokenMaxRunes) -} diff --git a/internal/tools/names.go b/internal/tools/names.go index a6d5dc91d..bcad15443 100644 --- a/internal/tools/names.go +++ b/internal/tools/names.go @@ -9,11 +9,7 @@ const ( ToolNameFilesystemGrep = "filesystem_grep" ToolNameFilesystemGlob = "filesystem_glob" ToolNameFilesystemEdit = "filesystem_edit" - ToolNameFilesystemMoveFile = "filesystem_move_file" - ToolNameFilesystemCopyFile = "filesystem_copy_file" ToolNameFilesystemDeleteFile = "filesystem_delete_file" - ToolNameFilesystemCreateDir = "filesystem_create_dir" - ToolNameFilesystemRemoveDir = 
"filesystem_remove_dir" ToolNameTodoWrite = "todo_write" ToolNameSpawnSubAgent = "spawn_subagent" ToolNameMemoRemember = "memo_remember" diff --git a/internal/tools/permission_mapper.go b/internal/tools/permission_mapper.go index 0dd71f922..7b7a820e3 100644 --- a/internal/tools/permission_mapper.go +++ b/internal/tools/permission_mapper.go @@ -113,20 +113,6 @@ func buildPermissionAction(input ToolCallInput) (security.Action, error) { action.Payload.Target = extractStringArgument(input.Arguments, "path") action.Payload.SandboxTargetType = security.TargetTypePath action.Payload.SandboxTarget = action.Payload.Target - case ToolNameFilesystemMoveFile: - action.Type = security.ActionTypeWrite - action.Payload.Operation = "move_file" - action.Payload.TargetType = security.TargetTypePath - action.Payload.Target = extractStringArgument(input.Arguments, "destination_path") - action.Payload.SandboxTargetType = security.TargetTypePath - action.Payload.SandboxTarget = action.Payload.Target - case ToolNameFilesystemCopyFile: - action.Type = security.ActionTypeWrite - action.Payload.Operation = "copy_file" - action.Payload.TargetType = security.TargetTypePath - action.Payload.Target = extractStringArgument(input.Arguments, "destination_path") - action.Payload.SandboxTargetType = security.TargetTypePath - action.Payload.SandboxTarget = action.Payload.Target case ToolNameFilesystemDeleteFile: action.Type = security.ActionTypeWrite action.Payload.Operation = "delete_file" @@ -134,20 +120,6 @@ func buildPermissionAction(input ToolCallInput) (security.Action, error) { action.Payload.Target = extractStringArgument(input.Arguments, "path") action.Payload.SandboxTargetType = security.TargetTypePath action.Payload.SandboxTarget = action.Payload.Target - case ToolNameFilesystemCreateDir: - action.Type = security.ActionTypeWrite - action.Payload.Operation = "create_dir" - action.Payload.TargetType = security.TargetTypePath - action.Payload.Target = extractStringArgument(input.Arguments, 
"path") - action.Payload.SandboxTargetType = security.TargetTypePath - action.Payload.SandboxTarget = action.Payload.Target - case ToolNameFilesystemRemoveDir: - action.Type = security.ActionTypeWrite - action.Payload.Operation = "remove_dir" - action.Payload.TargetType = security.TargetTypePath - action.Payload.Target = extractStringArgument(input.Arguments, "path") - action.Payload.SandboxTargetType = security.TargetTypePath - action.Payload.SandboxTarget = action.Payload.Target case ToolNameTodoWrite: action.Type = security.ActionTypeWrite action.Payload.Operation = "todo_write" diff --git a/internal/tools/registry.go b/internal/tools/registry.go index 16aee9168..99d8d8659 100644 --- a/internal/tools/registry.go +++ b/internal/tools/registry.go @@ -13,22 +13,17 @@ import ( ) type Registry struct { - tools map[string]Tool - microCompactPolicies map[string]MicroCompactPolicy - microCompactSummarizers map[string]ContentSummarizer - microCompactSummaryMu sync.RWMutex - mcpMu sync.RWMutex - mcpRegistry *mcp.Registry - mcpFactory *mcp.AdapterFactory - mcpExposureFilter mcp.ExposureFilter - mcpExposureAudit []mcp.ExposureDecision + tools map[string]Tool + mcpMu sync.RWMutex + mcpRegistry *mcp.Registry + mcpFactory *mcp.AdapterFactory + mcpExposureFilter mcp.ExposureFilter + mcpExposureAudit []mcp.ExposureDecision } func NewRegistry() *Registry { return &Registry{ - tools: map[string]Tool{}, - microCompactPolicies: map[string]MicroCompactPolicy{}, - microCompactSummarizers: map[string]ContentSummarizer{}, + tools: map[string]Tool{}, } } @@ -95,12 +90,6 @@ func (r *Registry) Register(tool Tool) { } name := strings.ToLower(tool.Name()) r.tools[name] = tool - switch tool.MicroCompactPolicy() { - case MicroCompactPolicyPreserveHistory: - r.microCompactPolicies[name] = MicroCompactPolicyPreserveHistory - default: - r.microCompactPolicies[name] = MicroCompactPolicyCompact - } } func (r *Registry) Get(name string) (Tool, error) { @@ -119,49 +108,6 @@ func (r *Registry) 
Supports(name string) bool { return r.supportsMCPTool(name) } -// MicroCompactPolicy returns the micro compact policy of the given tool; unknown tools default to compactable. -func (r *Registry) MicroCompactPolicy(name string) MicroCompactPolicy { - if r == nil { - return MicroCompactPolicyCompact - } - policy, ok := r.microCompactPolicies[strings.ToLower(strings.TrimSpace(name))] - if !ok { - return MicroCompactPolicyCompact - } - if policy == MicroCompactPolicyPreserveHistory { - return MicroCompactPolicyPreserveHistory - } - return MicroCompactPolicyCompact -} - -// RegisterSummarizer registers a content summarizer for the given tool; passing nil removes the existing entry. -func (r *Registry) RegisterSummarizer(toolName string, summarizer ContentSummarizer) { - if r == nil { - return - } - name := strings.ToLower(strings.TrimSpace(toolName)) - r.microCompactSummaryMu.Lock() - defer r.microCompactSummaryMu.Unlock() - if summarizer == nil { - delete(r.microCompactSummarizers, name) - return - } - r.microCompactSummarizers[name] = summarizer -} - -// MicroCompactSummarizer returns the content summarizer of the given tool, or nil if none is registered. -func (r *Registry) MicroCompactSummarizer(name string) ContentSummarizer { - if r == nil { - return nil - } - r.microCompactSummaryMu.RLock() - defer r.microCompactSummaryMu.RUnlock() - if r.microCompactSummarizers == nil { - return nil - } - return r.microCompactSummarizers[strings.ToLower(strings.TrimSpace(name))] -} - func (r *Registry) GetSpecs() []providertypes.ToolSpec { names := make([]string, 0, len(r.tools)) for name := range r.tools { diff --git a/internal/tools/registry_test.go b/internal/tools/registry_test.go index 1191317cc..0d18b3c1f 100644 --- a/internal/tools/registry_test.go +++ b/internal/tools/registry_test.go @@ -14,7 +14,6 @@ type stubTool struct { name string description string schema map[string]any - policy MicroCompactPolicy result ToolResult err error } @@ -24,9 +23,6 @@ func (s stubTool) Description() string { return s.description } func (s stubTool) Schema() map[string]any { return s.schema } -func (s stubTool) MicroCompactPolicy() MicroCompactPolicy { - 
return s.policy -} func (s stubTool) Execute(ctx context.Context, call ToolCallInput) (ToolResult, error) { return s.result, s.err } @@ -166,12 +162,6 @@ func TestRegistryHelpers(t *testing.T) { if registry.Supports("missing") { t.Fatalf("did not expect registry to support missing tool") } - if registry.MicroCompactPolicy("a_tool") != MicroCompactPolicyCompact { - t.Fatalf("expected compact policy default for a_tool") - } - if registry.MicroCompactPolicy("missing") != MicroCompactPolicyCompact { - t.Fatalf("expected compact policy default for unknown tool") - } schemas := registry.ListSchemas() if len(schemas) != 1 || schemas[0].Name != "a_tool" { @@ -194,43 +184,6 @@ func TestRegistryHelpers(t *testing.T) { } } -func TestRegistryMicroCompactPolicyPreserveHistory(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - registry.Register(stubTool{ - name: "custom_tool", - description: "preserve history", - schema: map[string]any{"type": "object"}, - policy: MicroCompactPolicyPreserveHistory, - }) - - if got := registry.MicroCompactPolicy("custom_tool"); got != MicroCompactPolicyPreserveHistory { - t.Fatalf("expected preserve history policy, got %q", got) - } -} - -func TestRegistryMicroCompactPolicyNormalizesNameAndNilRegistry(t *testing.T) { - t.Parallel() - - registry := NewRegistry() - registry.Register(stubTool{ - name: "Custom_Tool", - description: "preserve history", - schema: map[string]any{"type": "object"}, - policy: MicroCompactPolicyPreserveHistory, - }) - - if got := registry.MicroCompactPolicy(" custom_tool "); got != MicroCompactPolicyPreserveHistory { - t.Fatalf("expected normalized preserve history policy, got %q", got) - } - - var nilRegistry *Registry - if got := nilRegistry.MicroCompactPolicy("whatever"); got != MicroCompactPolicyCompact { - t.Fatalf("expected nil registry default compact policy, got %q", got) - } -} - func TestRegistryRememberSessionDecisionUnsupported(t *testing.T) { t.Parallel() diff --git 
a/internal/tools/spawnsubagent/tool.go b/internal/tools/spawnsubagent/tool.go index 95c07335c..7b472f0e7 100644 --- a/internal/tools/spawnsubagent/tool.go +++ b/internal/tools/spawnsubagent/tool.go @@ -7,12 +7,12 @@ import ( "encoding/json" "errors" "fmt" + "neo-code/internal/tools" "path/filepath" "strings" "time" "neo-code/internal/subagent" - "neo-code/internal/tools" ) const ( @@ -113,11 +113,6 @@ func (t *Tool) Schema() map[string]any { } } -// MicroCompactPolicy preserves sub-agent results so that the analysis chain and conclusions are not lost during short-term compaction. -func (t *Tool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyPreserveHistory -} - // Execute parses the input arguments and then runs in inline mode. func (t *Tool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { if err := ctx.Err(); err != nil { diff --git a/internal/tools/spawnsubagent/tool_test.go b/internal/tools/spawnsubagent/tool_test.go index c06c1b1bd..a551c5157 100644 --- a/internal/tools/spawnsubagent/tool_test.go +++ b/internal/tools/spawnsubagent/tool_test.go @@ -38,9 +38,6 @@ func TestToolMetadata(t *testing.T) { if strings.TrimSpace(tool.Description()) == "" { t.Fatalf("Description() should not be empty") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyPreserveHistory { - t.Fatalf("MicroCompactPolicy() = %q, want %q", tool.MicroCompactPolicy(), tools.MicroCompactPolicyPreserveHistory) - } schema := tool.Schema() properties, ok := schema["properties"].(map[string]any) if !ok { diff --git a/internal/tools/todo/write.go b/internal/tools/todo/write.go index 716d64788..693965d85 100644 --- a/internal/tools/todo/write.go +++ b/internal/tools/todo/write.go @@ -4,10 +4,10 @@ import ( "context" "errors" "fmt" + "neo-code/internal/tools" "strings" agentsession "neo-code/internal/session" - "neo-code/internal/tools" ) // Tool is the session-level Todo read/write tool implementation. @@ -263,11 +263,6 @@ func (t *Tool) Schema() map[string]any { } } -// MicroCompactPolicy returns the tool's micro compact policy. -func (t *Tool) MicroCompactPolicy() tools.MicroCompactPolicy { - return 
tools.MicroCompactPolicyCompact -} - // Execute 执行 todo_write 的 action 分发,并把变更写回会话。 func (t *Tool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { if err := ctx.Err(); err != nil { diff --git a/internal/tools/todo/write_test.go b/internal/tools/todo/write_test.go index 0c11e6dfb..a944dae89 100644 --- a/internal/tools/todo/write_test.go +++ b/internal/tools/todo/write_test.go @@ -228,9 +228,6 @@ func TestToolMetadataMethods(t *testing.T) { if strings.TrimSpace(tool.Description()) == "" { t.Fatalf("Description() should not be empty") } - if tool.MicroCompactPolicy() != tools.MicroCompactPolicyCompact { - t.Fatalf("MicroCompactPolicy() should be compact") - } schema := tool.Schema() if schema["type"] != "object" { t.Fatalf("Schema() type = %+v", schema["type"]) diff --git a/internal/tools/types.go b/internal/tools/types.go index 554308c4f..f3589f848 100644 --- a/internal/tools/types.go +++ b/internal/tools/types.go @@ -15,7 +15,6 @@ type Tool interface { Name() string Description() string Schema() map[string]any - MicroCompactPolicy() MicroCompactPolicy Execute(ctx context.Context, call ToolCallInput) (ToolResult, error) } diff --git a/internal/tools/webfetch/tool.go b/internal/tools/webfetch/tool.go index 3b353902a..44f4fb7b6 100644 --- a/internal/tools/webfetch/tool.go +++ b/internal/tools/webfetch/tool.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "mime" + "neo-code/internal/tools" "net" "net/http" "net/url" @@ -13,7 +14,6 @@ import ( "time" "neo-code/internal/config" - "neo-code/internal/tools" ) const ( @@ -116,11 +116,6 @@ func (t *Tool) Schema() map[string]any { } } -// MicroCompactPolicy 声明 webfetch 工具的历史结果默认参与 micro compact 清理。 -func (t *Tool) MicroCompactPolicy() tools.MicroCompactPolicy { - return tools.MicroCompactPolicyCompact -} - func (t *Tool) Execute(ctx context.Context, call tools.ToolCallInput) (tools.ToolResult, error) { in, err := decodeInput(call.Arguments) if err != nil { diff --git a/internal/tui/core/app/update.go 
b/internal/tui/core/app/update.go index 8c0fa8a7b..6ad4c32a1 100644 --- a/internal/tui/core/app/update.go +++ b/internal/tui/core/app/update.go @@ -3030,6 +3030,9 @@ var runtimeEventHandlerRegistry = map[tuiservices.EventType]func(*App, tuiservic tuiservices.EventCompactError: runtimeEventCompactErrorHandler, tuiservices.EventTokenUsage: runtimeEventTokenUsageHandler, tuiservices.EventPhaseChanged: runtimeEventPhaseChangedHandler, + tuiservices.EventVerificationStarted: runtimeEventVerificationStartedHandler, + tuiservices.EventVerificationStageFinished: runtimeEventVerificationStageFinishedHandler, + tuiservices.EventVerificationFinished: runtimeEventVerificationFinishedHandler, tuiservices.EventVerificationCompleted: runtimeEventVerificationCompletedHandler, tuiservices.EventVerificationFailed: runtimeEventVerificationFailedHandler, tuiservices.EventAcceptanceDecided: runtimeEventAcceptanceDecidedHandler, @@ -3525,6 +3528,75 @@ func runtimeEventPhaseChangedHandler(a *App, event tuiservices.RuntimeEvent) boo return false } +// runtimeEventVerificationStartedHandler 处理验证流程开始事件并记录 completion gate 结果。 +func runtimeEventVerificationStartedHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.VerificationStartedPayload) + if !ok { + return false + } + detail := "completion_gate=pass" + if !payload.CompletionPassed { + detail = "completion_gate=blocked" + } + if reason := strings.TrimSpace(payload.CompletionBlockedReason); reason != "" { + detail = detail + " (" + reason + ")" + } + a.appendActivity("verify", "Verification started", detail, false) + return false +} + +// runtimeEventVerificationStageFinishedHandler 处理单个验证阶段完成事件并展示结果摘要。 +func runtimeEventVerificationStageFinishedHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.VerificationStageFinishedPayload) + if !ok { + return false + } + stageName := strings.TrimSpace(payload.Name) + if stageName == "" { + stageName = 
"unknown_stage" + } + status := strings.ToLower(strings.TrimSpace(payload.Status)) + title := "Verification stage passed" + isError := false + if status != "pass" { + title = "Verification stage failed" + isError = true + } + detail := stageName + if summary := strings.TrimSpace(payload.Summary); summary != "" { + detail = detail + " | " + summary + } else if reason := strings.TrimSpace(payload.Reason); reason != "" { + detail = detail + " | " + reason + } + if class := strings.TrimSpace(payload.ErrorClass); class != "" { + detail = detail + " | class=" + class + } + a.appendActivity("verify", title, detail, isError) + return false +} + +// runtimeEventVerificationFinishedHandler 处理验证流程结束事件并输出最终 acceptance 状态。 +func runtimeEventVerificationFinishedHandler(a *App, event tuiservices.RuntimeEvent) bool { + payload, ok := event.Payload.(tuiservices.VerificationFinishedPayload) + if !ok { + return false + } + acceptanceStatus := strings.TrimSpace(payload.AcceptanceStatus) + if acceptanceStatus == "" { + acceptanceStatus = "unknown" + } + detail := "acceptance_status=" + acceptanceStatus + if reason := strings.TrimSpace(string(payload.StopReason)); reason != "" { + detail = detail + " | stop=" + reason + } + if class := strings.TrimSpace(payload.ErrorClass); class != "" { + detail = detail + " | class=" + class + } + isError := strings.EqualFold(acceptanceStatus, "failed") + a.appendActivity("verify", "Verification finished", detail, isError) + return false +} + // runtimeEventVerificationCompletedHandler 处理验证通过事件。 func runtimeEventVerificationCompletedHandler(a *App, event tuiservices.RuntimeEvent) bool { payload, ok := event.Payload.(tuiservices.VerificationCompletedPayload) diff --git a/internal/tui/core/app/update_runtime_events_test.go b/internal/tui/core/app/update_runtime_events_test.go index 84016b035..488185e84 100644 --- a/internal/tui/core/app/update_runtime_events_test.go +++ b/internal/tui/core/app/update_runtime_events_test.go @@ -181,6 +181,15 @@ func 
TestRuntimeEventHandlerRegistryContainsRenamedEvents(t *testing.T) { if _, ok := runtimeEventHandlerRegistry[agentruntime.EventVerificationCompleted]; !ok { t.Fatalf("expected verification_completed handler to be registered") } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventVerificationStarted]; !ok { + t.Fatalf("expected verification_started handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventVerificationStageFinished]; !ok { + t.Fatalf("expected verification_stage_finished handler to be registered") + } + if _, ok := runtimeEventHandlerRegistry[agentruntime.EventVerificationFinished]; !ok { + t.Fatalf("expected verification_finished handler to be registered") + } if _, ok := runtimeEventHandlerRegistry[agentruntime.EventVerificationFailed]; !ok { t.Fatalf("expected verification_failed handler to be registered") } @@ -709,6 +718,61 @@ func TestRuntimeEventMultimodalHandlers(t *testing.T) { func TestRuntimeEventVerificationAndAcceptanceHandlers(t *testing.T) { app, _ := newTestApp(t) + if handled := runtimeEventVerificationStartedHandler(&app, agentruntime.RuntimeEvent{Payload: 1}); handled { + t.Fatalf("expected invalid verification_started payload to return false") + } + runtimeEventVerificationStartedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.VerificationStartedPayload{ + CompletionPassed: false, + CompletionBlockedReason: "pending_todo", + }, + }) + started := app.activities[len(app.activities)-1] + if started.Title != "Verification started" || !strings.Contains(started.Detail, "completion_gate=blocked") || started.IsError { + t.Fatalf("unexpected started activity: %+v", started) + } + + if handled := runtimeEventVerificationStageFinishedHandler(&app, agentruntime.RuntimeEvent{Payload: 1}); handled { + t.Fatalf("expected invalid verification_stage_finished payload to return false") + } + runtimeEventVerificationStageFinishedHandler(&app, agentruntime.RuntimeEvent{ + Payload: 
agentruntime.VerificationStageFinishedPayload{ + Name: "required_todo", + Status: "fail", + Reason: "todo remains open", + ErrorClass: "unknown", + }, + }) + stage := app.activities[len(app.activities)-1] + if stage.Title != "Verification stage failed" || !strings.Contains(stage.Detail, "required_todo") || !stage.IsError { + t.Fatalf("unexpected stage activity: %+v", stage) + } + + if handled := runtimeEventVerificationFinishedHandler(&app, agentruntime.RuntimeEvent{Payload: 1}); handled { + t.Fatalf("expected invalid verification_finished payload to return false") + } + runtimeEventVerificationFinishedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.VerificationFinishedPayload{ + AcceptanceStatus: "failed", + StopReason: agentruntime.StopReasonVerificationFailed, + ErrorClass: "unknown", + }, + }) + finished := app.activities[len(app.activities)-1] + if finished.Title != "Verification finished" || !strings.Contains(finished.Detail, "acceptance_status=failed") || !finished.IsError { + t.Fatalf("unexpected finished activity: %+v", finished) + } + runtimeEventVerificationFinishedHandler(&app, agentruntime.RuntimeEvent{ + Payload: agentruntime.VerificationFinishedPayload{ + AcceptanceStatus: "continued", + StopReason: agentruntime.StopReasonAcceptContinue, + }, + }) + continued := app.activities[len(app.activities)-1] + if continued.Title != "Verification finished" || !strings.Contains(continued.Detail, "acceptance_status=continued") || continued.IsError { + t.Fatalf("unexpected continued activity: %+v", continued) + } + if handled := runtimeEventVerificationCompletedHandler(&app, agentruntime.RuntimeEvent{Payload: 1}); handled { t.Fatalf("expected invalid verification_completed payload to return false") } diff --git a/internal/tui/services/gateway_stream_client.go b/internal/tui/services/gateway_stream_client.go index 92dd02281..86d7eb1f9 100644 --- a/internal/tui/services/gateway_stream_client.go +++ b/internal/tui/services/gateway_stream_client.go 
@@ -205,6 +205,12 @@ func restoreRuntimePayload(eventType EventType, payload any) (any, error) { return decodeRuntimePayload[PhaseChangedPayload](payload) case EventStopReasonDecided: return decodeStopReasonPayload(payload) + case EventVerificationStarted: + return decodeRuntimePayload[VerificationStartedPayload](payload) + case EventVerificationStageFinished: + return decodeRuntimePayload[VerificationStageFinishedPayload](payload) + case EventVerificationFinished: + return decodeRuntimePayload[VerificationFinishedPayload](payload) case EventVerificationCompleted: return decodeRuntimePayload[VerificationCompletedPayload](payload) case EventVerificationFailed: diff --git a/internal/tui/services/gateway_stream_client_additional_test.go b/internal/tui/services/gateway_stream_client_additional_test.go index 9edd3c7c0..6b6a4b0e3 100644 --- a/internal/tui/services/gateway_stream_client_additional_test.go +++ b/internal/tui/services/gateway_stream_client_additional_test.go @@ -169,6 +169,63 @@ func TestRestoreRuntimePayloadCoversSpecializedTypes(t *testing.T) { } }, }, + { + name: "verification started", + eventType: EventVerificationStarted, + payload: map[string]any{ + "completion_passed": false, + "completion_blocked_reason": "pending_todo", + }, + assertFn: func(t *testing.T, got any) { + t.Helper() + value, ok := got.(VerificationStartedPayload) + if !ok { + t.Fatalf("payload type = %T", got) + } + if value.CompletionPassed || value.CompletionBlockedReason != "pending_todo" { + t.Fatalf("payload = %#v", value) + } + }, + }, + { + name: "verification stage finished", + eventType: EventVerificationStageFinished, + payload: map[string]any{ + "name": "required_todo", + "status": "fail", + "reason": "open todo remains", + "error_class": "unknown", + }, + assertFn: func(t *testing.T, got any) { + t.Helper() + value, ok := got.(VerificationStageFinishedPayload) + if !ok { + t.Fatalf("payload type = %T", got) + } + if value.Name != "required_todo" || value.Status != "fail" 
|| value.ErrorClass != "unknown" { + t.Fatalf("payload = %#v", value) + } + }, + }, + { + name: "verification finished", + eventType: EventVerificationFinished, + payload: map[string]any{ + "acceptance_status": "failed", + "stop_reason": "verification_failed", + "error_class": "unknown", + }, + assertFn: func(t *testing.T, got any) { + t.Helper() + value, ok := got.(VerificationFinishedPayload) + if !ok { + t.Fatalf("payload type = %T", got) + } + if value.AcceptanceStatus != "failed" || value.StopReason != StopReasonVerificationFailed || value.ErrorClass != "unknown" { + t.Fatalf("payload = %#v", value) + } + }, + }, { name: "runtime usage payload", eventType: EventType(RuntimeEventUsage), diff --git a/internal/tui/services/runtime_contract.go b/internal/tui/services/runtime_contract.go index d7480d47b..25acf7fc1 100644 --- a/internal/tui/services/runtime_contract.go +++ b/internal/tui/services/runtime_contract.go @@ -2,6 +2,7 @@ package services import ( "context" + "sort" "time" providertypes "neo-code/internal/provider/types" @@ -338,6 +339,12 @@ type StopReasonDecidedPayload struct { Detail string `json:"detail,omitempty"` } +// VerificationStartedPayload 描述验证流程启动事件。 +type VerificationStartedPayload struct { + CompletionPassed bool `json:"completion_passed"` + CompletionBlockedReason string `json:"completion_blocked_reason,omitempty"` +} + // VerificationStageFinishedPayload 描述单个 verifier 阶段结果。 type VerificationStageFinishedPayload struct { Name string `json:"name"` @@ -718,3 +725,116 @@ const ( EventDecisionMade EventType = "decision_made" EventTodoSnapshotUpdated EventType = "todo_snapshot_updated" ) + +// contractEntry 描述单个事件类型的契约声明。 +type contractEntry struct { + RequireConsumer bool +} + +// contractRegistry 声明 TUI 侧已知的事件类型及其消费者要求。 +// RequireConsumer=true 表示该事件必须有对应的 gateway decode 分支与 TUI 消费者; +// RequireConsumer=false 表示该事件允许透传(passthrough),不要求显式消费。 +var contractRegistry = map[EventType]contractEntry{ + // --- 已有 decode 分支的事件(RequireConsumer=true)--- + 
EventUserMessage: {RequireConsumer: true}, + EventAgentDone: {RequireConsumer: true}, + EventToolStart: {RequireConsumer: true}, + EventToolResult: {RequireConsumer: true}, + EventPermissionRequested: {RequireConsumer: true}, + EventPermissionResolved: {RequireConsumer: true}, + EventUserQuestionRequested: {RequireConsumer: true}, + EventUserQuestionAnswered: {RequireConsumer: true}, + EventUserQuestionTimeout: {RequireConsumer: true}, + EventUserQuestionSkipped: {RequireConsumer: true}, + EventCompactApplied: {RequireConsumer: true}, + EventCompactError: {RequireConsumer: true}, + EventTokenUsage: {RequireConsumer: true}, + EventPhaseChanged: {RequireConsumer: true}, + EventStopReasonDecided: {RequireConsumer: true}, + EventVerificationStarted: {RequireConsumer: true}, + EventVerificationStageFinished: {RequireConsumer: true}, + EventVerificationFinished: {RequireConsumer: true}, + EventVerificationCompleted: {RequireConsumer: true}, + EventVerificationFailed: {RequireConsumer: true}, + EventAcceptanceDecided: {RequireConsumer: true}, + EventInputNormalized: {RequireConsumer: true}, + EventAssetSaved: {RequireConsumer: true}, + EventAssetSaveFailed: {RequireConsumer: true}, + EventHookStarted: {RequireConsumer: true}, + EventHookFinished: {RequireConsumer: true}, + EventHookFailed: {RequireConsumer: true}, + EventHookBlocked: {RequireConsumer: true}, + EventHookNotification: {RequireConsumer: true}, + EventRepoHooksDiscovered: {RequireConsumer: true}, + EventRepoHooksLoaded: {RequireConsumer: true}, + EventRepoHooksSkippedUntrusted: {RequireConsumer: true}, + EventRepoHooksTrustStoreInvalid: {RequireConsumer: true}, + EventCheckpointCreated: {RequireConsumer: true}, + EventCheckpointWarning: {RequireConsumer: true}, + EventCheckpointRestored: {RequireConsumer: true}, + EventCheckpointUndoRestore: {RequireConsumer: true}, + EventToolDiff: {RequireConsumer: true}, + EventBashSideEffect: {RequireConsumer: true}, + EventTodoUpdated: {RequireConsumer: true}, + 
EventTodoConflict: {RequireConsumer: true}, + EventTodoSnapshotUpdated: {RequireConsumer: true}, + EventSubAgentStarted: {RequireConsumer: true}, + EventSubAgentProgress: {RequireConsumer: true}, + EventSubAgentRetried: {RequireConsumer: true}, + EventSubAgentBlocked: {RequireConsumer: true}, + EventSubAgentCompleted: {RequireConsumer: true}, + EventSubAgentFailed: {RequireConsumer: true}, + EventSubAgentCanceled: {RequireConsumer: true}, + EventSubAgentFinished: {RequireConsumer: true}, + EventSubAgentToolCallStarted: {RequireConsumer: true}, + EventSubAgentToolCallResult: {RequireConsumer: true}, + EventSubAgentToolCallDenied: {RequireConsumer: true}, + EventRuntimeSnapshotUpdated: {RequireConsumer: true}, + EventSubAgentSnapshotUpdated: {RequireConsumer: true}, + EventDecisionMade: {RequireConsumer: true}, + + // --- 字符串类 payload 事件(有 decode 分支,透传字符串)--- + EventAgentChunk: {RequireConsumer: true}, + EventToolChunk: {RequireConsumer: true}, + EventError: {RequireConsumer: true}, + EventToolCallThinking: {RequireConsumer: true}, + EventCompactStart: {RequireConsumer: true}, + + // --- 显式声明为透传安全(passthrough-safe)的事件 --- + // 这些事件在 runtime 侧产生但不要求 TUI 显式消费, + // 未在 gateway decode 中处理时会以原始 payload 透传。 + EventRunCanceled: {RequireConsumer: false}, + EventSkillActivated: {RequireConsumer: false}, + EventSkillDeactivated: {RequireConsumer: false}, + EventSkillMissing: {RequireConsumer: false}, + EventProgressEvaluated: {RequireConsumer: false}, + EventTodoSummaryInjected: {RequireConsumer: false}, +} + +// RegisteredEventTypes 返回所有已注册的契约事件类型(排序后)。 +func RegisteredEventTypes() []EventType { + types := make([]EventType, 0, len(contractRegistry)) + for eventType := range contractRegistry { + types = append(types, eventType) + } + sort.Slice(types, func(i, j int) bool { + return types[i] < types[j] + }) + return types +} + +// RequireConsumer 返回指定事件类型是否要求显式消费者。 +// 若事件类型未注册,返回 false(允许透传)。 +func RequireConsumer(eventType EventType) bool { + entry, ok := 
contractRegistry[eventType] + if !ok { + return false + } + return entry.RequireConsumer +} + +// IsRegisteredEventType 返回指定事件类型是否已注册到契约中。 +func IsRegisteredEventType(eventType EventType) bool { + _, ok := contractRegistry[eventType] + return ok +} diff --git a/internal/tui/services/runtime_contract_test.go b/internal/tui/services/runtime_contract_test.go new file mode 100644 index 000000000..1ab523ff7 --- /dev/null +++ b/internal/tui/services/runtime_contract_test.go @@ -0,0 +1,677 @@ +package services + +import ( + "go/ast" + "go/parser" + "go/token" + "os" + "path/filepath" + "sort" + "strings" + "testing" +) + +// runtimeContractEventSourceFiles 定义 runtime 事件常量的源文件列表。 +var runtimeContractEventSourceFiles = []string{ + "internal/runtime/events.go", + "internal/runtime/events_subagent.go", +} + +// legacyPassthroughEvents 是已知的遗留透传事件,允许不注册到 contractRegistry。 +// 新增的 runtime Event* 常量必须显式注册到 contractRegistry,否则 CI 失败。 +var legacyPassthroughEvents = map[string]struct{}{ + "thinking_delta": {}, + "plan_updated": {}, + "budget_checked": {}, + "budget_estimate_failed": {}, + "ledger_reconciled": {}, + "repository_context_unavailable": {}, + "resume_applied": {}, + "run_diff_summary": {}, +} + +// TestRegisteredEventTypesSorted 验证 RegisteredEventTypes 返回排序后的列表。 +func TestRegisteredEventTypesSorted(t *testing.T) { + types := RegisteredEventTypes() + if len(types) == 0 { + t.Fatal("RegisteredEventTypes returned empty slice") + } + for i := 1; i < len(types); i++ { + if types[i] < types[i-1] { + t.Fatalf("RegisteredEventTypes not sorted: %q < %q at index %d", types[i], types[i-1], i) + } + } +} + +// TestRequireConsumerKnownEvents 验证已知的 RequireConsumer=true 事件被正确注册。 +func TestRequireConsumerKnownEvents(t *testing.T) { + mustRequireConsumer := []EventType{ + EventUserMessage, + EventToolStart, + EventToolResult, + EventPermissionRequested, + EventCompactApplied, + EventTokenUsage, + EventHookStarted, + EventCheckpointCreated, + EventSubAgentStarted, + 
EventRuntimeSnapshotUpdated, + EventDecisionMade, + } + for _, eventType := range mustRequireConsumer { + if !RequireConsumer(eventType) { + t.Errorf("expected RequireConsumer(%q) = true, got false", eventType) + } + } +} + +// TestRequireConsumerUnregistered 验证未注册事件返回 false(允许透传)。 +func TestRequireConsumerUnregistered(t *testing.T) { + if RequireConsumer("nonexistent_event") { + t.Error("expected RequireConsumer for unregistered event to return false") + } +} + +// TestRequireConsumerPassthroughEvents 验证显式声明为透传安全的事件返回 false。 +func TestRequireConsumerPassthroughEvents(t *testing.T) { + passthroughEvents := []EventType{ + EventRunCanceled, + EventSkillActivated, + EventSkillDeactivated, + EventSkillMissing, + EventProgressEvaluated, + EventTodoSummaryInjected, + } + for _, eventType := range passthroughEvents { + if RequireConsumer(eventType) { + t.Errorf("expected RequireConsumer(%q) = false for passthrough event, got true", eventType) + } + } +} + +// TestIsRegisteredEventType 验证事件注册查询。 +func TestIsRegisteredEventType(t *testing.T) { + if !IsRegisteredEventType(EventUserMessage) { + t.Error("expected EventUserMessage to be registered") + } + if IsRegisteredEventType("totally_unknown_event") { + t.Error("expected unknown event to not be registered") + } +} + +// TestRuntimeEventContractConsistency 扫描 runtime 事件常量并与 contractRegistry 求差集。 +// 未注册且不在 legacyPassthroughEvents 中的事件会导致测试失败。 +func TestRuntimeEventContractConsistency(t *testing.T) { + runtimeEventValues := collectRuntimeEventConstants(t) + if len(runtimeEventValues) == 0 { + t.Fatal("no runtime Event* constants found in events.go / events_subagent.go") + } + + registeredTypes := make(map[EventType]struct{}, len(contractRegistry)) + for eventType := range contractRegistry { + registeredTypes[eventType] = struct{}{} + } + + var unregistered []string + for _, eventValue := range runtimeEventValues { + eventType := EventType(eventValue) + if _, registered := registeredTypes[eventType]; registered { + continue + 
} + if _, legacy := legacyPassthroughEvents[eventValue]; legacy { + t.Logf("runtime event %q in legacyPassthroughEvents allowlist (passthrough allowed)", eventValue) + continue + } + unregistered = append(unregistered, eventValue) + } + + if len(unregistered) > 0 { + sort.Strings(unregistered) + t.Fatalf( + "runtime events not registered in contractRegistry and not in legacyPassthroughEvents:\n %s\n\n"+ + "Fix: add to contractRegistry in runtime_contract.go with explicit RequireConsumer decision, "+ + "or add to legacyPassthroughEvents if passthrough is acceptable.", + strings.Join(unregistered, "\n "), + ) + } + + // 反向检查:contractRegistry 中 RequireConsumer=true 的事件是否都在 runtime 中定义 + runtimeEventSet := make(map[string]struct{}, len(runtimeEventValues)) + for _, v := range runtimeEventValues { + runtimeEventSet[v] = struct{}{} + } + var ghostEvents []string + for eventType, entry := range contractRegistry { + if !entry.RequireConsumer { + continue + } + if isTUIBridgeEvent(eventType) { + continue + } + if _, exists := runtimeEventSet[string(eventType)]; !exists { + ghostEvents = append(ghostEvents, string(eventType)) + } + } + + if len(ghostEvents) > 0 { + sort.Strings(ghostEvents) + t.Fatalf( + "contractRegistry events with RequireConsumer=true not found in runtime events.go:\n %s\n\n"+ + "Fix: remove from contractRegistry or add the event to runtime events.go.", + strings.Join(ghostEvents, "\n "), + ) + } +} + +// TestGatewayDecodeBranchConsistency 扫描 gateway_stream_client.go 的 decode 分支, +// 验证所有 decode 分支中处理的事件类型都在 contractRegistry 中注册。 +func TestGatewayDecodeBranchConsistency(t *testing.T) { + decodedConstNames := collectGatewayDecodeConstNames(t) + if len(decodedConstNames) == 0 { + t.Fatal("no decode branches found in restoreRuntimePayload") + } + + // 构建 contractRegistry 值到 EventType 的反向映射 + valueToEventType := make(map[string]EventType, len(contractRegistry)) + for eventType := range contractRegistry { + valueToEventType[string(eventType)] = eventType + } + 
+ // 构建 contractRegistry 中所有已注册的事件值集合 + registeredValues := make(map[string]struct{}, len(contractRegistry)) + for eventType := range contractRegistry { + registeredValues[string(eventType)] = struct{}{} + } + + // TUI bridge 事件值 + bridgeValues := map[string]struct{}{ + "run_context": {}, + "tool_status": {}, + "usage": {}, + } + + for _, constName := range decodedConstNames { + // 如果是字符串值(如 "user_message"),直接检查 + if _, registered := registeredValues[constName]; registered { + continue + } + // 如果是 bridge 事件,跳过 + if _, isBridge := bridgeValues[constName]; isBridge { + continue + } + // 如果是常量名(如 "EventUserMessage"),尝试解析 + if resolvedValue, ok := resolveConstNameToValue(constName); ok { + if _, registered := registeredValues[resolvedValue]; registered { + continue + } + } + t.Errorf( + "gateway decode branch handles %q but it is not registered in contractRegistry; "+ + "add it to contractRegistry with RequireConsumer=true", + constName, + ) + } +} + +// TestRequireConsumerMustHaveDecodeBranch 验证 contractRegistry 中 RequireConsumer=true 的事件 +// 必须在 gateway_stream_client.go 中有对应的 decode 分支。 +// 这是 CI 防漏的关键测试:新增 RequireConsumer=true 事件但忘记添加 decode 分支时,此测试失败。 +func TestRequireConsumerMustHaveDecodeBranch(t *testing.T) { + decodedValues := collectGatewayDecodeConstNames(t) + decodedSet := make(map[string]struct{}, len(decodedValues)) + for _, v := range decodedValues { + decodedSet[v] = struct{}{} + } + + // bridge 事件值 + bridgeValues := map[string]struct{}{ + RuntimeEventRunContext: {}, + RuntimeEventToolStatus: {}, + RuntimeEventUsage: {}, + } + + var violations []string + for eventType, entry := range contractRegistry { + if !entry.RequireConsumer { + continue + } + value := string(eventType) + if _, decoded := decodedSet[value]; decoded { + continue + } + if _, isBridge := bridgeValues[value]; isBridge { + continue + } + violations = append(violations, value) + } + + if len(violations) > 0 { + sort.Strings(violations) + t.Fatalf( + "contractRegistry events with 
RequireConsumer=true missing gateway decode branch:\n %s\n\n"+ + "Fix: add a decode branch in restoreRuntimePayload (gateway_stream_client.go), or "+ + "set RequireConsumer=false in contractRegistry if passthrough is acceptable.", + strings.Join(violations, "\n "), + ) + } +} + +// TestRequireConsumerMustHaveTUIConsumer 验证 contractRegistry 中 RequireConsumer=true 的事件 +// 必须在 TUI update.go 的 runtimeEventHandlerRegistry 中有对应的 handler。 +// 这确保事件不会在 decode 后被 handleRuntimeEvent 静默丢弃。 +func TestRequireConsumerMustHaveTUIConsumer(t *testing.T) { + tuiHandlerEvents := collectTUIEventHandlerEvents(t) + if len(tuiHandlerEvents) == 0 { + t.Fatal("no events found in runtimeEventHandlerRegistry") + } + + tuiHandlerSet := make(map[string]struct{}, len(tuiHandlerEvents)) + for _, v := range tuiHandlerEvents { + tuiHandlerSet[v] = struct{}{} + } + + var violations []string + for eventType, entry := range contractRegistry { + if !entry.RequireConsumer { + continue + } + value := string(eventType) + if _, handled := tuiHandlerSet[value]; handled { + continue + } + // bridge 事件在 update.go 中也有 handler,但如果缺失也算违规 + violations = append(violations, value) + } + + if len(violations) > 0 { + sort.Strings(violations) + t.Fatalf( + "contractRegistry events with RequireConsumer=true missing TUI event handler:\n %s\n\n"+ + "Fix: add a handler in runtimeEventHandlerRegistry (internal/tui/core/app/update.go), or "+ + "set RequireConsumer=false in contractRegistry if TUI consumption is not required.", + strings.Join(violations, "\n "), + ) + } +} + +// isTUIBridgeEvent 判断事件是否为 TUI 侧特有的 bridge 事件(非 runtime 产生)。 +func isTUIBridgeEvent(eventType EventType) bool { + bridgeEvents := map[EventType]struct{}{ + EventType(RuntimeEventRunContext): {}, + EventType(RuntimeEventToolStatus): {}, + EventType(RuntimeEventUsage): {}, + EventRunCanceled: {}, + } + _, ok := bridgeEvents[eventType] + return ok +} + +// resolveConstNameToValue 尝试将常量名(如 "EventUserMessage")解析为字符串值(如 "user_message")。 +// 使用 
contractRegistry 的键作为已知值映射。 +func resolveConstNameToValue(constName string) (string, bool) { + // 从 gateway_stream_client.go 中的 EventType(RuntimeEventXxx) 模式 + // 这些是 bridge 事件,值在 runtime_bridge.go 中定义 + bridgeConstMap := map[string]string{ + "RuntimeEventRunContext": RuntimeEventRunContext, + "RuntimeEventToolStatus": RuntimeEventToolStatus, + "RuntimeEventUsage": RuntimeEventUsage, + } + if value, ok := bridgeConstMap[constName]; ok { + return value, true + } + return "", false +} + +// collectTUIEventHandlerEvents 从 update.go 的 runtimeEventHandlerRegistry 中提取已注册的事件值。 +func collectTUIEventHandlerEvents(t *testing.T) []string { + t.Helper() + + projectRoot := findProjectRoot(t) + filePath := filepath.Join(projectRoot, "internal", "tui", "core", "app", "update.go") + + src, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("read %s: %v", filePath, err) + } + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, filePath, src, parser.ParseComments) + if err != nil { + t.Fatalf("parse %s: %v", filePath, err) + } + + // 从 runtime_contract.go 构建常量名→值映射 + constNameToValue := buildEventTypeConstMap(t, filepath.Join(projectRoot, "internal", "tui", "services", "runtime_contract.go")) + // bridge 常量 + bridgeConstMap := buildBridgeConstMap(t, filepath.Join(projectRoot, "internal", "tui", "services", "runtime_bridge.go")) + for k, v := range bridgeConstMap { + constNameToValue[k] = v + } + + var eventValues []string + ast.Inspect(file, func(n ast.Node) bool { + // 查找 runtimeEventHandlerRegistry 的 map literal + valueSpec, ok := n.(*ast.ValueSpec) + if !ok { + return true + } + // 查找 var runtimeEventHandlerRegistry = map[...]...{...} + isRegistry := false + for _, name := range valueSpec.Names { + if name.Name == "runtimeEventHandlerRegistry" { + isRegistry = true + break + } + } + if !isRegistry || len(valueSpec.Values) == 0 { + return true + } + + // 解析 composite literal 中的 key + for _, val := range valueSpec.Values { + compositeLit, ok := 
val.(*ast.CompositeLit) + if !ok { + continue + } + for _, elt := range compositeLit.Elts { + kvExpr, ok := elt.(*ast.KeyValueExpr) + if !ok { + continue + } + eventValue := extractEventValueFromExpr(kvExpr.Key, constNameToValue) + if eventValue != "" { + eventValues = append(eventValues, eventValue) + } + } + } + return true + }) + return eventValues +} + +// extractEventValueFromExpr 从 AST 表达式中提取事件字符串值。 +func extractEventValueFromExpr(expr ast.Expr, constNameToValue map[string]string) string { + switch v := expr.(type) { + case *ast.Ident: + // tuiservices.EventXxx + if value, ok := constNameToValue[v.Name]; ok { + return value + } + return v.Name + case *ast.SelectorExpr: + // tuiservices.EventXxx → 提取 EventXxx + if ident, ok := v.X.(*ast.Ident); ok && ident.Name == "tuiservices" { + if value, ok := constNameToValue[v.Sel.Name]; ok { + return value + } + return v.Sel.Name + } + case *ast.CallExpr: + // tuiservices.EventType(tuiservices.RuntimeEventXxx) → 提取 RuntimeEventXxx + if funIdent, ok := v.Fun.(*ast.Ident); ok && funIdent.Name == "EventType" { + if len(v.Args) > 0 { + return extractEventValueFromExpr(v.Args[0], constNameToValue) + } + } + // tuiservices.EventType(tuiservices.RuntimeEventXxx) via SelectorExpr + if sel, ok := v.Fun.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok && ident.Name == "tuiservices" && sel.Sel.Name == "EventType" { + if len(v.Args) > 0 { + return extractEventValueFromExpr(v.Args[0], constNameToValue) + } + } + } + } + return "" +} + +// collectRuntimeEventConstants 从 runtime 事件源文件中提取所有 Event* 常量值。 +func collectRuntimeEventConstants(t *testing.T) []string { + t.Helper() + + projectRoot := findProjectRoot(t) + var allValues []string + for _, relPath := range runtimeContractEventSourceFiles { + filePath := filepath.Join(projectRoot, filepath.FromSlash(relPath)) + allValues = append(allValues, extractEventConstValues(t, filePath)...) 
+ } + return allValues +} + +// extractEventConstValues 使用 AST 解析提取文件中 EventType 类型的常量值。 +func extractEventConstValues(t *testing.T, filePath string) []string { + t.Helper() + + src, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("read %s: %v", filePath, err) + } + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, filePath, src, parser.ParseComments) + if err != nil { + t.Fatalf("parse %s: %v", filePath, err) + } + + var values []string + for _, decl := range file.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.CONST { + continue + } + for _, spec := range genDecl.Specs { + valueSpec, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + // 检查类型是否为 EventType + if valueSpec.Type == nil { + continue + } + typeIdent, ok := valueSpec.Type.(*ast.Ident) + if !ok || typeIdent.Name != "EventType" { + continue + } + for i, name := range valueSpec.Names { + if !strings.HasPrefix(name.Name, "Event") { + continue + } + if i < len(valueSpec.Values) { + if basicLit, ok := valueSpec.Values[i].(*ast.BasicLit); ok { + // 去掉引号 + value := strings.Trim(basicLit.Value, "\"") + values = append(values, value) + } + } + } + } + } + return values +} + +// collectGatewayDecodeConstNames 从 gateway_stream_client.go 的 restoreRuntimePayload 中提取解码的事件类型值。 +// 对于常量引用(如 EventUserMessage),通过解析同包 const 声明解析为字符串值。 +func collectGatewayDecodeConstNames(t *testing.T) []string { + t.Helper() + + projectRoot := findProjectRoot(t) + + // 从 runtime_contract.go 构建常量名→值映射 + constNameToValue := buildEventTypeConstMap(t, filepath.Join(projectRoot, "internal", "tui", "services", "runtime_contract.go")) + // 从 runtime_bridge.go 构建 bridge 常量名→值映射 + bridgeConstMap := buildBridgeConstMap(t, filepath.Join(projectRoot, "internal", "tui", "services", "runtime_bridge.go")) + for k, v := range bridgeConstMap { + constNameToValue[k] = v + } + + filePath := filepath.Join(projectRoot, "internal", "tui", "services", "gateway_stream_client.go") + + src, err := 
os.ReadFile(filePath) + if err != nil { + t.Fatalf("read %s: %v", filePath, err) + } + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, filePath, src, parser.ParseComments) + if err != nil { + t.Fatalf("parse %s: %v", filePath, err) + } + + var eventValues []string + ast.Inspect(file, func(n ast.Node) bool { + caseClause, ok := n.(*ast.CaseClause) + if !ok { + return true + } + for _, expr := range caseClause.List { + switch v := expr.(type) { + case *ast.Ident: + // case EventXxx: + if strings.HasPrefix(v.Name, "Event") || strings.HasPrefix(v.Name, "RuntimeEvent") { + if value, ok := constNameToValue[v.Name]; ok { + eventValues = append(eventValues, value) + } else { + // 无法解析的常量名,保留原始名称用于诊断 + eventValues = append(eventValues, v.Name) + } + } + case *ast.CallExpr: + // case EventType("xxx") 或 case EventType(RuntimeEventXxx): + if funIdent, ok := v.Fun.(*ast.Ident); ok && funIdent.Name == "EventType" { + if len(v.Args) > 0 { + switch arg := v.Args[0].(type) { + case *ast.BasicLit: + // case EventType("xxx"): + value := strings.Trim(arg.Value, "\"") + eventValues = append(eventValues, value) + case *ast.Ident: + // case EventType(RuntimeEventXxx): + if value, ok := constNameToValue[arg.Name]; ok { + eventValues = append(eventValues, value) + } else { + eventValues = append(eventValues, arg.Name) + } + } + } + } + } + } + return true + }) + return eventValues +} + +// buildEventTypeConstMap 从 runtime_contract.go 中提取 EventType 常量名→值映射。 +func buildEventTypeConstMap(t *testing.T, filePath string) map[string]string { + t.Helper() + return extractConstStringMap(t, filePath, "EventType") +} + +// buildBridgeConstMap 从 runtime_bridge.go 中提取常量名→值映射(包括无类型常量)。 +func buildBridgeConstMap(t *testing.T, filePath string) map[string]string { + t.Helper() + + src, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("read %s: %v", filePath, err) + } + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, filePath, src, parser.ParseComments) + if 
err != nil { + t.Fatalf("parse %s: %v", filePath, err) + } + + result := make(map[string]string) + for _, decl := range file.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.CONST { + continue + } + for _, spec := range genDecl.Specs { + valueSpec, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + // 只提取以 RuntimeEvent 开头的常量 + for i, name := range valueSpec.Names { + if !strings.HasPrefix(name.Name, "RuntimeEvent") { + continue + } + if i < len(valueSpec.Values) { + if basicLit, ok := valueSpec.Values[i].(*ast.BasicLit); ok { + value := strings.Trim(basicLit.Value, "\"") + result[name.Name] = value + } + } + } + } + } + return result +} + +// extractConstStringMap 从指定文件中提取指定类型的 const 字符串映射。 +func extractConstStringMap(t *testing.T, filePath string, typeName string) map[string]string { + t.Helper() + + src, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("read %s: %v", filePath, err) + } + + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, filePath, src, parser.ParseComments) + if err != nil { + t.Fatalf("parse %s: %v", filePath, err) + } + + result := make(map[string]string) + for _, decl := range file.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.CONST { + continue + } + for _, spec := range genDecl.Specs { + valueSpec, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + if valueSpec.Type == nil { + continue + } + typeIdent, ok := valueSpec.Type.(*ast.Ident) + if !ok || typeIdent.Name != typeName { + continue + } + for i, name := range valueSpec.Names { + if i < len(valueSpec.Values) { + if basicLit, ok := valueSpec.Values[i].(*ast.BasicLit); ok { + value := strings.Trim(basicLit.Value, "\"") + result[name.Name] = value + } + } + } + } + } + return result +} + +// findProjectRoot 向上查找 go.mod 所在目录。 +func findProjectRoot(t *testing.T) string { + t.Helper() + + dir, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + for { + if _, err := 
os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + t.Fatal("could not find project root (go.mod not found)") + } + dir = parent + } +} diff --git a/web/src/api/gateway.test.ts b/web/src/api/gateway.test.ts index 336b3f403..ae1f50a29 100644 --- a/web/src/api/gateway.test.ts +++ b/web/src/api/gateway.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, vi, beforeEach } from 'vitest' +import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest' import { GatewayAPI } from './gateway' import { Method } from './protocol' @@ -13,6 +13,10 @@ describe('GatewayAPI', () => { api = new GatewayAPI(ws) }) + afterEach(() => { + vi.unstubAllGlobals() + }) + it('maps authenticate and run methods', async () => { await api.authenticate('tok') await api.run({ input_text: 'hello' }) @@ -21,6 +25,14 @@ describe('GatewayAPI', () => { expect(call).toHaveBeenNthCalledWith(2, Method.Run, { input_text: 'hello' }) }) + it('maps createSession method', async () => { + await api.createSession() + await api.createSession('s1') + + expect(call).toHaveBeenNthCalledWith(1, Method.CreateSession, {}) + expect(call).toHaveBeenNthCalledWith(2, Method.CreateSession, { session_id: 's1' }) + }) + it('maps optional session_id in listModels', async () => { await api.listModels() await api.listModels('s1') @@ -53,10 +65,93 @@ describe('GatewayAPI', () => { it('maps permission and user question resolution', async () => { await api.resolvePermission({ request_id: 'r1', decision: 'allow_once' }) + await api.approvePlan({ session_id: 's1', plan_id: 'p1', revision: 2 }) await api.resolveUserQuestion({ request_id: 'q1', status: 'answered', message: 'ok' }) expect(call).toHaveBeenNthCalledWith(1, Method.ResolvePermission, { request_id: 'r1', decision: 'allow_once' }) - expect(call).toHaveBeenNthCalledWith(2, Method.UserQuestionAnswer, { request_id: 'q1', status: 'answered', message: 'ok' }) + expect(call).toHaveBeenNthCalledWith(2, 
Method.ApprovePlan, { session_id: 's1', plan_id: 'p1', revision: 2 }) + expect(call).toHaveBeenNthCalledWith(3, Method.UserQuestionAnswer, { request_id: 'q1', status: 'answered', message: 'ok' }) + }) + + it('uploads session assets with bearer auth, workspace header, and multipart body', async () => { + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + json: () => Promise.resolve({ session_id: 's1', asset_id: 'asset-1', mime_type: 'image/png', size: 3 }), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, 'http://localhost:1455/', ' token-1 ') + + const file = new File(['abc'], 'a.png', { type: 'image/png' }) + const result = await api.uploadSessionAsset('s1', file, 'workspace-b') + + expect(result.asset_id).toBe('asset-1') + expect(fetchMock).toHaveBeenCalledWith('http://localhost:1455/api/session-assets', expect.objectContaining({ + method: 'POST', + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-b' }, + })) + const init = fetchMock.mock.calls[0][1] as RequestInit + expect(init.body).toBeInstanceOf(FormData) + expect((init.body as FormData).get('session_id')).toBe('s1') + expect((init.body as FormData).get('file')).toBe(file) + }) + + it('fetches session asset blobs with bearer auth and workspace header', async () => { + const blob = new Blob(['img'], { type: 'image/png' }) + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + blob: () => Promise.resolve(blob), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, '/gateway', 'token-1') + + await expect(api.fetchSessionAsset('s 1', 'asset/1', 'workspace-b')).resolves.toBe(blob) + expect(fetchMock).toHaveBeenCalledWith('/gateway/api/session-assets/s%201/asset%2F1', { + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-b' }, + }) + }) + + it('deletes session assets with bearer auth and workspace header', async () => { + const fetchMock = vi.fn().mockResolvedValue({ ok: true }) + 
vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, '/gateway', 'token-1') + + await api.deleteSessionAsset('s 1', 'asset/1', 'workspace-b') + + expect(fetchMock).toHaveBeenCalledWith('/gateway/api/session-assets/s%201/asset%2F1', { + method: 'DELETE', + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-b' }, + }) + }) + + it('uses switched workspace as session asset HTTP fallback', async () => { + call.mockResolvedValueOnce({ type: 'ack', payload: { workspace_hash: 'workspace-c' } }) + const fetchMock = vi.fn().mockResolvedValue({ + ok: true, + blob: () => Promise.resolve(new Blob(['img'])), + }) + vi.stubGlobal('fetch', fetchMock) + api = new GatewayAPI(ws, '', 'token-1') + + await api.switchWorkspace('workspace-c') + await api.fetchSessionAsset('s1', 'asset-1') + await api.deleteSessionAsset('s1', 'asset-1') + + expect(fetchMock).toHaveBeenNthCalledWith(1, '/api/session-assets/s1/asset-1', { + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-c' }, + }) + expect(fetchMock).toHaveBeenNthCalledWith(2, '/api/session-assets/s1/asset-1', { + method: 'DELETE', + headers: { Authorization: 'Bearer token-1', 'X-NeoCode-Workspace-Hash': 'workspace-c' }, + }) + }) + + it('surfaces session asset HTTP errors', async () => { + vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ + ok: false, + status: 415, + json: () => Promise.resolve({ error: 'unsupported image type' }), + })) + await expect(api.uploadSessionAsset('s1', new File(['x'], 'x.txt'))).rejects.toThrow('unsupported image type') }) }) diff --git a/web/src/api/gateway.ts b/web/src/api/gateway.ts index acbff8808..060acd9b6 100644 --- a/web/src/api/gateway.ts +++ b/web/src/api/gateway.ts @@ -5,6 +5,8 @@ import { type AuthenticateParams, type BindStreamParams, type RunParams, + type CreateSessionParams, + type CreateSessionResult, type CancelParams, type LoadSessionParams, type ListSessionTodosParams, @@ -18,6 +20,8 @@ import { type 
CheckpointDiffParams, type CheckpointDiffResult, type ResolvePermissionParams, + type ApprovePlanParams, + type ApprovePlanResult, type ResolveUserQuestionParams, type Session, type RunAckResult, @@ -71,14 +75,20 @@ import { type RenameWorkspaceResult, type DeleteWorkspaceParams, type DeleteWorkspaceResult, + type SessionAssetUploadResult, } from './protocol' /** Gateway business API client, built on a full-duplex WebSocket channel */ export class GatewayAPI { private ws: WSClient + private baseURL: string + private token: string + private currentWorkspaceHash = '' - constructor(ws: WSClient) { + constructor(ws: WSClient, baseURL = '', token = '') { this.ws = ws + this.baseURL = baseURL.replace(/\/+$/, '') + this.token = token.trim() } /** Authenticate; returns the ack result */ @@ -91,11 +101,59 @@ export class GatewayAPI { return this.ws.call(Method.BindStream, params) } + /** Explicitly create a session so asset ownership is established before sending images */ + async createSession(sessionId?: string) { + const params: CreateSessionParams = sessionId ? { session_id: sessionId } : {} + return this.ws.call(Method.CreateSession, params) + } + /** Start a run; returns an ack containing session_id and run_id */ async run(params: RunParams) { return this.ws.call(Method.Run, params) } + /** Upload a session image attachment; returns an asset_id that can be referenced in input_parts */ + async uploadSessionAsset(sessionId: string, file: File, workspaceHash = '') { + const form = new FormData() + form.append('session_id', sessionId) + form.append('file', file) + const res = await fetch(`${this.baseURL}/api/session-assets`, { + method: 'POST', + headers: this.httpHeaders(workspaceHash), + body: form, + }) + if (!res.ok) { + throw new Error(await readHTTPError(res, 'Upload failed')) + } + return res.json() as Promise<SessionAssetUploadResult> + } + + /** Fetch a session image attachment as a Blob, used for thumbnails in message history */ + async fetchSessionAsset(sessionId: string, assetId: string, workspaceHash = '') { + const res = await fetch( + `${this.baseURL}/api/session-assets/${encodeURIComponent(sessionId)}/${encodeURIComponent(assetId)}`, + { headers: this.httpHeaders(workspaceHash) }, + ) + if (!res.ok) { + throw new Error(await 
readHTTPError(res, 'Asset fetch failed')) + } + return res.blob() + } + + /** Delete a session image attachment, for server-side cleanup after a send is cancelled or an uploaded reference is removed */ + async deleteSessionAsset(sessionId: string, assetId: string, workspaceHash = '') { + const res = await fetch( + `${this.baseURL}/api/session-assets/${encodeURIComponent(sessionId)}/${encodeURIComponent(assetId)}`, + { + method: 'DELETE', + headers: this.httpHeaders(workspaceHash), + }, + ) + if (!res.ok) { + throw new Error(await readHTTPError(res, 'Asset delete failed')) + } + } + /** Cancel a run; returns the cancellation result */ async cancel(params: CancelParams) { return this.ws.call(Method.Cancel, params) @@ -144,6 +202,10 @@ export class GatewayAPI { return this.ws.call(Method.ResolvePermission, params) } + async approvePlan(params: ApprovePlanParams) { + return this.ws.call(Method.ApprovePlan, params) + } + /** Submit the answer to an ask_user question */ async resolveUserQuestion(params: ResolveUserQuestionParams) { return this.ws.call(Method.UserQuestionAnswer, params) @@ -284,7 +346,9 @@ export class GatewayAPI { /** Switch workspace */ async switchWorkspace(workspaceHash: string) { - return this.ws.call(Method.SwitchWorkspace, { workspace_hash: workspaceHash } satisfies SwitchWorkspaceParams) + const result = await this.ws.call(Method.SwitchWorkspace, { workspace_hash: workspaceHash } satisfies SwitchWorkspaceParams) + this.currentWorkspaceHash = workspaceHash.trim() + return result } /** Rename workspace */ @@ -296,4 +360,21 @@ export class GatewayAPI { async deleteWorkspace(workspaceHash: string, removeData?: boolean) { return this.ws.call(Method.DeleteWorkspace, { workspace_hash: workspaceHash, remove_data: removeData } satisfies DeleteWorkspaceParams) } + + getCurrentWorkspaceHash() { + return this.currentWorkspaceHash + } + + private httpHeaders(workspaceHash = '') { + const headers: Record<string, string> = {} + if (this.token) headers.Authorization = `Bearer ${this.token}` + const resolvedWorkspaceHash = workspaceHash.trim() || this.currentWorkspaceHash + if (resolvedWorkspaceHash) headers['X-NeoCode-Workspace-Hash'] = 
resolvedWorkspaceHash + return Object.keys(headers).length > 0 ? headers : undefined + } +} + +async function readHTTPError(res: Response, fallback: string) { + const data = await res.json().catch(() => null) as { error?: string } | null + return data?.error || `${fallback} (HTTP ${res.status})` } diff --git a/web/src/api/protocol.ts b/web/src/api/protocol.ts index 69c4164e9..84ea510e6 100644 --- a/web/src/api/protocol.ts +++ b/web/src/api/protocol.ts @@ -11,6 +11,7 @@ export const Method = { Ping: "gateway.ping", BindStream: "gateway.bindStream", Run: "gateway.run", + CreateSession: "gateway.createSession", Cancel: "gateway.cancel", Compact: "gateway.compact", ListSessions: "gateway.listSessions", @@ -21,6 +22,7 @@ export const Method = { UndoRestore: "checkpoint.undoRestore", CheckpointDiff: "checkpoint.diff", ResolvePermission: "gateway.resolvePermission", + ApprovePlan: "gateway.approvePlan", UserQuestionAnswer: "gateway.userQuestionAnswer", ExecuteSystemTool: "gateway.executeSystemTool", ActivateSessionSkill: "gateway.activateSessionSkill", @@ -62,6 +64,7 @@ export const FrameType = { // 帧动作 export const FrameAction = { Run: "run", + ApprovePlan: "approve_plan", ListProviders: "list_providers", CreateCustomProvider: "create_custom_provider", DeleteCustomProvider: "delete_custom_provider", @@ -77,6 +80,7 @@ export const EventType = { UserMessage: "user_message", AgentChunk: "agent_chunk", AgentDone: "agent_done", + PlanUpdated: "plan_updated", ToolStart: "tool_start", ToolResult: "tool_result", ToolDiff: "tool_diff", @@ -101,6 +105,7 @@ export const EventType = { BudgetEstimateFailed: "budget_estimate_failed", LedgerReconciled: "ledger_reconciled", StopReasonDecided: "stop_reason_decided", + RunError: "run_error", InputNormalized: "input_normalized", SkillActivated: "skill_activated", SkillDeactivated: "skill_deactivated", @@ -134,7 +139,20 @@ export const StopReason = { FatalError: "fatal_error", BudgetExceeded: "budget_exceeded", MaxTurnExceeded: 
"max_turn_exceeded", + VerificationFailed: "verification_failed", Accepted: "accepted", + EmptyResponse: "empty_response", + AcceptContinue: "accept_continue", + AcceptContinueExhausted: "accept_continue_exhausted", + TodoNotConverged: "todo_not_converged", + TodoWaitingExternal: "todo_waiting_external", + RepeatCycle: "repeat_cycle", + MaxTurnExceededWithUnconvergedTodos: "max_turn_exceeded_with_unconverged_todos", + MaxTurnExceededWithFailedVerification: "max_turn_exceeded_with_failed_verification", + VerificationConfigMissing: "verification_config_missing", + VerificationExecutionDenied: "verification_execution_denied", + VerificationExecutionError: "verification_execution_error", + RequiredTodoFailed: "required_todo_failed", RetryExhausted: "retry_exhausted", } as const; @@ -217,9 +235,15 @@ export interface RunParams { export interface RunInputPart { type: string; text?: string; - media?: { uri: string; mime_type: string; file_name?: string }; + media?: { uri?: string; asset_id?: string; mime_type: string; file_name?: string }; } +export interface CreateSessionParams { + session_id?: string; +} + +export type CreateSessionResult = RPCResult<{ session_id: string }>; + /** gateway.cancel 参数 */ export interface CancelParams { session_id?: string; @@ -264,6 +288,12 @@ export interface ResolvePermissionParams { decision: string; } +export interface ApprovePlanParams { + session_id: string; + plan_id: string; + revision: number; +} + /** gateway.userQuestionAnswer 参数 */ export interface ResolveUserQuestionParams { request_id: string; @@ -284,11 +314,19 @@ export interface SessionSummary { export interface SessionMessage { role: string; content: string; + parts?: RunInputPart[]; tool_calls?: ToolCall[]; tool_call_id?: string; is_error?: boolean; } +export interface SessionAssetUploadResult { + session_id: string; + asset_id: string; + mime_type: string; + size: number; +} + /** 工具调用 */ export interface ToolCall { id: string; @@ -296,6 +334,47 @@ export interface 
ToolCall { arguments: string; } +export interface PlanTodoItem { + id: string; + content: string; + status?: string; + required?: boolean; + artifacts?: string[]; + failure_reason?: string; + blocked_reason?: string; + revision?: number; +} + +export interface PlanSpec { + goal: string; + steps?: string[]; + constraints?: string[]; + todos?: PlanTodoItem[]; + open_questions?: string[]; +} + +export interface PlanSummaryView { + goal: string; + key_steps?: string[]; + constraints?: string[]; + active_todo_ids?: string[]; +} + +export interface PlanArtifact { + id: string; + revision: number; + status: string; + spec: PlanSpec; + summary: PlanSummaryView; + created_at: string; + updated_at: string; +} + +export interface PlanUpdatedPayload { + current_plan?: PlanArtifact; + display_text?: string; +} + /** 会话详情 */ export interface Session { id: string; @@ -306,6 +385,7 @@ export interface Session { provider?: string; model?: string; agent_mode?: string; + current_plan?: PlanArtifact; messages?: SessionMessage[]; } @@ -357,6 +437,8 @@ export type ListSessionsResult = RPCResult<{ sessions: SessionSummary[] }>; /** gateway.cancel 响应 */ export type CancelResult = RPCResult<{ canceled: boolean; run_id: string }>; +export type ApprovePlanResult = RPCResult<{ plan_id: string; revision: number; status: string }>; + export interface TodoViewItem { id: string; content: string; @@ -622,6 +704,7 @@ export interface ModelEntry { id: string; name: string; provider: string; + capability_hints?: ProviderModelCapabilityHints; } /** gateway.listModels 响应 */ diff --git a/web/src/components/chat/ChatInput.test.tsx b/web/src/components/chat/ChatInput.test.tsx index 53829db08..fcbef263e 100644 --- a/web/src/components/chat/ChatInput.test.tsx +++ b/web/src/components/chat/ChatInput.test.tsx @@ -6,10 +6,14 @@ import { useComposerStore } from '@/stores/useComposerStore' import { useSessionStore } from '@/stores/useSessionStore' import { useRuntimeInsightStore } from 
'@/stores/useRuntimeInsightStore' import { useGatewayStore } from '@/stores/useGatewayStore' +import { useWorkspaceStore } from '@/stores/useWorkspaceStore' const mockGatewayAPI = { listAvailableSkills: vi.fn(), listModels: vi.fn(), + createSession: vi.fn(), + uploadSessionAsset: vi.fn(), + deleteSessionAsset: vi.fn(), run: vi.fn(), bindStream: vi.fn(), cancel: vi.fn(), @@ -68,9 +72,24 @@ describe('ChatInput', () => { selected_model_id: '', }, }) - - useComposerStore.setState({ composerText: '' }) + mockGatewayAPI.createSession.mockResolvedValue({ payload: { session_id: 'session-created' } }) + mockGatewayAPI.uploadSessionAsset.mockResolvedValue({ + session_id: 'session-created', + asset_id: 'asset-1', + mime_type: 'image/png', + size: 3, + }) + mockGatewayAPI.deleteSessionAsset.mockResolvedValue({}) + mockGatewayAPI.run.mockResolvedValue({ session_id: 'session-created', run_id: 'run-1' }) + mockGatewayAPI.bindStream.mockResolvedValue({}) + if (typeof URL.createObjectURL !== 'function') { + Object.defineProperty(URL, 'createObjectURL', { configurable: true, value: vi.fn() }) + } + vi.spyOn(URL, 'createObjectURL').mockReturnValue('blob:preview-1') + + useComposerStore.setState({ composerText: '', attachments: [] }) useSessionStore.setState({ currentSessionId: '' } as never) + useWorkspaceStore.setState({ currentWorkspaceHash: 'workspace-b' } as never) useGatewayStore.setState({ currentRunId: '' } as never) useRuntimeInsightStore.getState().reset() useChatStore.setState({ @@ -157,12 +176,158 @@ describe('ChatInput', () => { }) }) - it('does not render the unimplemented attachment and mention buttons', () => { + it('renders the image attachment picker but keeps mention button absent', () => { render(<ChatInput />) - expect(screen.queryByTitle('附件文件')).not.toBeInTheDocument() + expect(screen.getByRole('button', { name: /添加图片/ })).toBeInTheDocument() expect(screen.queryByTitle('引用上下文')).not.toBeInTheDocument() }) + + it('uploads selected image and sends image-only input parts after 
creating a session', async () => { + render(<ChatInput />) + + const file = new File(['img'], 'a.png', { type: 'image/png' }) + const input = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(input, { target: { files: [file] } }) + + await waitFor(() => { + expect(screen.getByAltText('a.png')).toBeInTheDocument() + }) + + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.createSession).toHaveBeenCalled() + expect(mockGatewayAPI.uploadSessionAsset).toHaveBeenCalledWith('session-created', file, 'workspace-b') + expect(mockGatewayAPI.run).toHaveBeenCalledWith({ + session_id: 'session-created', + input_parts: [ + { type: 'image', media: { asset_id: 'asset-1', mime_type: 'image/png', file_name: 'a.png' } }, + ], + mode: 'build', + }) + }) + expect(mockGatewayAPI.createSession.mock.invocationCallOrder[0]).toBeLessThan( + mockGatewayAPI.bindStream.mock.invocationCallOrder[0], + ) + expect(mockGatewayAPI.bindStream.mock.invocationCallOrder[0]).toBeLessThan( + mockGatewayAPI.uploadSessionAsset.mock.invocationCallOrder[0], + ) + expect(mockGatewayAPI.uploadSessionAsset.mock.invocationCallOrder[0]).toBeLessThan( + mockGatewayAPI.run.mock.invocationCallOrder[0], + ) + + expect(useChatStore.getState().messages[0]).toMatchObject({ + role: 'user', + attachments: [{ assetId: 'asset-1', previewUrl: 'blob:preview-1', workspaceHash: 'workspace-b' }], + }) + }) + + it('blocks image selection when the selected model explicitly rejects images', async () => { + mockGatewayAPI.listModels.mockResolvedValueOnce({ + payload: { + models: [{ + id: 'text-model', + name: 'Text Model', + provider: 'openai', + capability_hints: { image_input: 'unsupported' }, + }], + selected_provider_id: 'openai', + selected_model_id: 'text-model', + }, + }) + render(<ChatInput />) + + await waitFor(() => { + expect(screen.getByRole('button', { name: /添加图片/ })).toBeDisabled() + }) + const file = new File(['img'], 'a.png', { type: 'image/png' 
}) + const input = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(input, { target: { files: [file] } }) + + await waitFor(() => { + expect(useComposerStore.getState().attachments).toHaveLength(0) + }) + }) + + it('blocks sending existing image attachments when the selected model rejects images', async () => { + mockGatewayAPI.listModels.mockResolvedValueOnce({ + payload: { + models: [{ + id: 'text-model', + name: 'Text Model', + provider: 'openai', + capability_hints: { image_input: 'unsupported' }, + }], + selected_provider_id: 'openai', + selected_model_id: 'text-model', + }, + }) + render(<ChatInput />) + + await waitFor(() => { + expect(screen.getByRole('button', { name: /添加图片/ })).toBeDisabled() + }) + useComposerStore.getState().addAttachmentFiles([new File(['img'], 'a.png', { type: 'image/png' })]) + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.createSession).not.toHaveBeenCalled() + expect(mockGatewayAPI.uploadSessionAsset).not.toHaveBeenCalled() + expect(mockGatewayAPI.run).not.toHaveBeenCalled() + }) + }) + + it('deletes uploaded session assets when run fails', async () => { + useSessionStore.setState({ currentSessionId: 'session-1' } as never) + mockGatewayAPI.uploadSessionAsset.mockResolvedValueOnce({ + session_id: 'session-1', + asset_id: 'asset-failed', + mime_type: 'image/png', + size: 3, + }) + mockGatewayAPI.run.mockRejectedValueOnce(new Error('run failed')) + render(<ChatInput />) + + const file = new File(['img'], 'failed.png', { type: 'image/png' }) + const input = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(input, { target: { files: [file] } }) + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.deleteSessionAsset).toHaveBeenCalledWith('session-1', 'asset-failed', 'workspace-b') + }) + expect(useChatStore.getState().messages).toHaveLength(0) + }) + + it('treats 
slash text as a normal message when an image is attached', async () => { + useSessionStore.setState({ currentSessionId: 'session-1' } as never) + mockGatewayAPI.uploadSessionAsset.mockResolvedValueOnce({ + session_id: 'session-1', + asset_id: 'asset-2', + mime_type: 'image/png', + size: 3, + }) + render(<ChatInput />) + + const file = new File(['img'], 'slash.png', { type: 'image/png' }) + const fileInput = document.querySelector('input[type="file"]') as HTMLInputElement + fireEvent.change(fileInput, { target: { files: [file] } }) + fireEvent.change(screen.getByRole('textbox'), { target: { value: '/memo' } }) + fireEvent.keyDown(screen.getByRole('textbox'), { key: 'Enter' }) + + await waitFor(() => { + expect(mockGatewayAPI.executeSystemTool).not.toHaveBeenCalled() + expect(mockGatewayAPI.uploadSessionAsset).toHaveBeenCalledWith('session-1', file, 'workspace-b') + expect(mockGatewayAPI.run).toHaveBeenCalledWith(expect.objectContaining({ + session_id: 'session-1', + input_parts: [ + { type: 'text', text: '/memo' }, + { type: 'image', media: { asset_id: 'asset-2', mime_type: 'image/png', file_name: 'slash.png' } }, + ], + })) + }) + }) it('blocks normal sends while compaction is running', async () => { useChatStore.getState().startCompacting('manual', 'Compacting context...') render(<ChatInput />) diff --git a/web/src/components/chat/ChatInput.tsx b/web/src/components/chat/ChatInput.tsx index 6291a702b..568de7810 100644 --- a/web/src/components/chat/ChatInput.tsx +++ b/web/src/components/chat/ChatInput.tsx @@ -3,10 +3,17 @@ import { useChatStore, createUserMessage } from '@/stores/useChatStore' import { useGatewayStore } from '@/stores/useGatewayStore' import { useSessionStore, isValidSessionId } from '@/stores/useSessionStore' import { useUIStore } from '@/stores/useUIStore' -import { useComposerStore } from '@/stores/useComposerStore' +import { + acceptedImageMimeTypes, + maxComposerAttachmentBytes, + useComposerStore, + type ComposerAttachment, +} from '@/stores/useComposerStore' import { 
useRuntimeInsightStore } from '@/stores/useRuntimeInsightStore' +import { useWorkspaceStore } from '@/stores/useWorkspaceStore' import { formatTokenCount } from '@/utils/format' import { useGatewayAPI } from '@/context/RuntimeProvider' +import { type ModelEntry } from '@/api/protocol' import { builtinSlashCommands, matchSlashCommands, @@ -19,7 +26,7 @@ import { import SlashCommandMenu from './SlashCommandMenu' import SkillPicker from './SkillPicker' import ModelSelector from './ModelSelector' -import { Send, Square } from 'lucide-react' +import { ImagePlus, Loader2, Send, Square, X } from 'lucide-react' const slashMenuAnchorStyle: React.CSSProperties = { position: 'absolute', @@ -30,6 +37,12 @@ const slashMenuAnchorStyle: React.CSSProperties = { const budgetWarningThresholdRatio = 0.9 const budgetDangerThresholdRatio = 0.95 +const unsupportedImageInputMessage = '当前模型不支持图片输入,请切换支持图片的模型' + +type UploadedSessionAsset = { + attachment: ComposerAttachment + meta: { asset_id: string; mime_type: string; size?: number } +} /** 将网关返回的技能列表转换成输入框使用的 slash 命令结构。 */ function buildSkillSlashCommands( @@ -123,14 +136,22 @@ function resolveBudgetRingState( export default function ChatInput() { const gatewayAPI = useGatewayAPI() const text = useComposerStore((state) => state.composerText) + const attachments = useComposerStore((state) => state.attachments) const setText = useComposerStore((state) => state.setComposerText) + const addAttachmentFiles = useComposerStore((state) => state.addAttachmentFiles) + const removeAttachment = useComposerStore((state) => state.removeAttachment) + const clearAttachments = useComposerStore((state) => state.clearAttachments) + const setAttachmentStatus = useComposerStore((state) => state.setAttachmentStatus) const [rows, setRows] = useState(1) + const [dragActive, setDragActive] = useState(false) const textareaRef = useRef(null) + const fileInputRef = useRef(null) const runCancelledRef = useRef(false) const composingRef = useRef(false) const 
isGenerating = useChatStore((state) => state.isGenerating)
   const isCompacting = useChatStore((state) => state.isCompacting)
   const addMessage = useChatStore((state) => state.addMessage)
+  const removeMessage = useChatStore((state) => state.removeMessage)
   const addSystemMessage = useChatStore((state) => state.addSystemMessage)
   const setGenerating = useChatStore((state) => state.setGenerating)
   const sessionId = useSessionStore((state) => state.currentSessionId)
@@ -138,6 +159,9 @@ export default function ChatInput() {
   const setAgentMode = useChatStore((state) => state.setAgentMode)
   const permissionMode = useChatStore((state) => state.permissionMode)
   const setPermissionMode = useChatStore((state) => state.setPermissionMode)
+  const currentWorkspaceHash = useWorkspaceStore((state) => state.currentWorkspaceHash)
+  const providerChangeTick = useGatewayStore((state) => state.providerChangeTick)
+  const [currentImageInput, setCurrentImageInput] = useState('')
   const [showSlashMenu, setShowSlashMenu] = useState(false)
   const [selectedIndex, setSelectedIndex] = useState(0)
@@ -164,6 +188,26 @@ export default function ChatInput() {
     }
   }, [text, gatewayAPI, sessionId])
+  useEffect(() => {
+    if (!gatewayAPI) return
+    let cancelled = false
+    gatewayAPI.listModels(sessionId || undefined).then((result) => {
+      if (cancelled) return
+      const payload = result.payload
+      const selected = resolveSelectedModelEntry(
+        payload?.models || [],
+        payload?.selected_provider_id || '',
+        payload?.selected_model_id || '',
+      )
+      setCurrentImageInput(selected?.capability_hints?.image_input || '')
+    }).catch(() => {
+      if (!cancelled) setCurrentImageInput('')
+    })
+    return () => {
+      cancelled = true
+    }
+  }, [gatewayAPI, sessionId, providerChangeTick])
+
   useEffect(() => {
     if (!isSlashCommand(text)) {
       setMatchedCommands([])
@@ -302,7 +346,12 @@ export default function ChatInput() {
   async function handleSubmit() {
     const input = text.trim()
-    if (!input) return
+    const pendingAttachments = attachments
+    if (!input && pendingAttachments.length === 0) return
+    let submittedMessageId = ''
+    let targetSessionId = sessionId
+    let workspaceHash = currentWorkspaceHash.trim()
+    let uploaded: UploadedSessionAsset[] = []
     if (isCompacting) {
       useUIStore.getState().showToast('Context compaction is still running', 'info')
@@ -314,30 +363,68 @@ export default function ChatInput() {
       return
     }
-    if (isSlashCommand(input)) {
+    if (pendingAttachments.length === 0 && isSlashCommand(input)) {
       setText('')
       setShowSlashMenu(false)
       const handled = await executeSlashCommand(input)
       if (handled) return
     }
-    setText('')
-    const userMsg = createUserMessage(input)
-    addMessage(userMsg)
-    useRuntimeInsightStore.getState().setTodoSnapshot({
-      items: [],
-      summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 },
-    })
-    setGenerating(true)
-    runCancelledRef.current = false
+    if (pendingAttachments.length > 0 && currentImageInput === 'unsupported') {
+      useUIStore.getState().showToast(unsupportedImageInputMessage, 'error')
+      return
+    }
     try {
       if (!gatewayAPI) return
-      const isNewSession = !isValidSessionId(sessionId)
+      if (!isValidSessionId(targetSessionId)) {
+        const created = await gatewayAPI.createSession()
+        targetSessionId = created.payload?.session_id || ''
+        if (!isValidSessionId(targetSessionId)) throw new Error('Create session failed')
+        useSessionStore.getState().setCurrentSessionId(targetSessionId)
+        await gatewayAPI.bindStream({ session_id: targetSessionId, channel: 'all' }).catch(() => {})
+      }
+
+      workspaceHash = currentWorkspaceHash.trim()
+      for (const attachment of pendingAttachments) {
+        setAttachmentStatus(attachment.id, 'uploading')
+        try {
+          const meta = await gatewayAPI.uploadSessionAsset(targetSessionId, attachment.file, workspaceHash)
+          setAttachmentStatus(attachment.id, 'uploaded')
+          uploaded.push({ attachment, meta })
+        } catch (err) {
+          const message = err instanceof Error ? err.message : 'Upload failed'
+          setAttachmentStatus(attachment.id, 'error', message)
+          throw err
+        }
+      }
+
+      const inputParts = buildRunInputParts(input, uploaded)
+      const userMsg = createUserMessage(input, uploaded.map(({ attachment, meta }) => ({
+        id: attachment.id,
+        sessionId: targetSessionId,
+        workspaceHash,
+        assetId: meta.asset_id,
+        mimeType: meta.mime_type,
+        name: attachment.file.name,
+        size: meta.size,
+        previewUrl: attachment.previewUrl,
+      })))
+
+      setText('')
+      clearAttachments(false)
+      addMessage(userMsg)
+      submittedMessageId = userMsg.id
+      useRuntimeInsightStore.getState().setTodoSnapshot({
+        items: [],
+        summary: { total: 0, required_total: 0, required_completed: 0, required_failed: 0, required_open: 0 },
+      })
+      setGenerating(true)
+      runCancelledRef.current = false
+
       const ack = await gatewayAPI.run({
-        session_id: isNewSession ? undefined : sessionId,
-        new_session: isNewSession ? true : undefined,
-        input_text: input,
+        session_id: targetSessionId,
+        input_parts: inputParts,
         mode: agentMode,
       })
       if (!runCancelledRef.current) {
@@ -350,11 +437,16 @@ export default function ChatInput() {
         }
       }
     } catch (err) {
+      if (gatewayAPI && uploaded.length > 0 && isValidSessionId(targetSessionId)) {
+        await cleanupUploadedSessionAssets(gatewayAPI, targetSessionId, workspaceHash, uploaded)
+      }
       if (!runCancelledRef.current) {
+        if (submittedMessageId) {
+          removeMessage(submittedMessageId)
+        }
         setGenerating(false)
-        useChatStore.getState().removeMessage(userMsg.id)
         console.error('Run failed:', err)
-        useUIStore.getState().showToast('Failed to send message', 'error')
+        useUIStore.getState().showToast(err instanceof Error ? err.message : 'Failed to send message', 'error')
       }
     }
   }
@@ -421,6 +513,37 @@ export default function ChatInput() {
     void executeSlashCommand(cmd.usage)
   }
+  function handleFilesSelected(files: FileList | File[]) {
+    if (currentImageInput === 'unsupported') {
+      useUIStore.getState().showToast(unsupportedImageInputMessage, 'error')
+      return
+    }
+    const accepted: File[] = []
+    for (const file of Array.from(files)) {
+      if (!acceptedImageMimeTypes.includes(file.type as any)) {
+        useUIStore.getState().showToast('Only PNG, JPEG, and WebP images are supported', 'error')
+        continue
+      }
+      if (file.size <= 0) {
+        useUIStore.getState().showToast('Cannot upload an empty file', 'error')
+        continue
+      }
+      if (file.size > maxComposerAttachmentBytes) {
+        useUIStore.getState().showToast('Image exceeds the 20 MiB limit', 'error')
+        continue
+      }
+      accepted.push(file)
+    }
+    if (accepted.length > 0) addAttachmentFiles(accepted)
+  }
+
+  function handleDrop(e: React.DragEvent) {
+    e.preventDefault()
+    setDragActive(false)
+    if (controlsLocked) return
+    handleFilesSelected(e.dataTransfer.files)
+  }
+
   async function handleCancel() {
     runCancelledRef.current = true
     const runId = useGatewayStore.getState().currentRunId
@@ -439,7 +562,7 @@ export default function ChatInput() {
     }
   }
-  const isEmpty = !text.trim()
+  const isEmpty = !text.trim() && attachments.length === 0
   const controlsLocked = isGenerating || isCompacting
   return (
@@ -460,7 +583,16 @@ export default function ChatInput() {
         />
       )}
-
+
{
+          e.preventDefault()
+          if (!controlsLocked) setDragActive(true)
+        }}
+        onDragLeave={() => setDragActive(false)}
+        onDrop={handleDrop}
+      >
+