diff --git a/.kiro/steering/routing-and-orchestration.md b/.kiro/steering/routing-and-orchestration.md index 0266d3d1..4d93b3d9 100644 --- a/.kiro/steering/routing-and-orchestration.md +++ b/.kiro/steering/routing-and-orchestration.md @@ -59,6 +59,7 @@ The current selector language includes these core-owned behaviors: - parallel groups (`!`) that race multiple B-legs, - per-leg `[handicap=N]` start delays in parallel groups, - global and per-leaf `{ttft_timeout=N}` / `[ttft_timeout=N]` budgets, +- per-leaf query generation params that override matching per-request body/call options, - model aliases that rewrite full selector strings before parsing. Mixing incompatible selector forms must fail early. In particular, parallel `!` groups cannot be mixed with `^`, weights, or `[first]` in the same arm. diff --git a/README.md b/README.md index 6f9a51ca..9f8f9996 100644 --- a/README.md +++ b/README.md @@ -47,9 +47,9 @@ go run ./cmd/lipstd --config ./config/config.yaml ## Configuration and operations - **Config** - Runtime config is typed and loaded from YAML. [`config/config.yaml`](config/config.yaml) documents access/auth templates, server timeouts, logging, diagnostics, observability, routing, continuity, and provider rows. [`config/config.multi-instance.example.yaml`](config/config.multi-instance.example.yaml) shows multiple backend instances of the same adapter. -- **Routing** - Default selectors come from `routing.default_route` or the first enabled backend plus registry default model ids. `model_aliases` rewrite full selector strings before parsing. Route selectors support ordered failover, weights, first-request annotations, parallel `!` races, per-leg `[handicap=N]`, and global/per-leg TTFT budgets. +- **Routing** - Default selectors come from `routing.default_route` or the first enabled backend plus registry default model ids. `model_aliases` rewrite full selector strings before parsing. Route selectors support ordered failover, weights, first-request annotations, parallel `!` races, per-leg `[handicap=N]`, global/per-leg TTFT budgets, and per-leaf query generation parameters. Route query parameters such as `?reasoning_effort=xhigh` are explicit routing directives: when present, they override matching per-request body/canonical generation options; absent parameters leave request values unchanged. - **Continuity** - `continuity.store: memory` is the default. `continuity.store: sqlite` with `continuity.sqlite_path` persists A-leg rows and attempt lineage through [`internal/core/continuity/sqlitestore`](internal/core/continuity/sqlitestore). In-memory `ttl` and `max_legs` tuning does not apply to SQLite. -- **Security** - Multi-user or non-loopback deployments need explicit auth/access posture. Local API keys must be at least 16 Unicode code points after trimming. Diagnostics, pprof, metrics, model-catalog diagnostics, and secure-session summaries require a shared secret when exposed beyond loopback. +- **Security** - Multi-user or non-loopback deployments need explicit auth/access posture. Local API keys must be at least 16 Unicode code points after trimming. Diagnostics, pprof, metrics, model-catalog diagnostics, and secure-session summaries require a shared secret when exposed beyond loopback. On Unix, OpenAI Codex `auth.json` and managed-OAuth account files must be `0600` (group/other-readable files are now rejected at load); symlinked managed-OAuth account files are skipped. See [`docs/openai-codex-backend.md`](docs/openai-codex-backend.md#token-file-permissions). - **Observability** - Optional Prometheus metrics and OpenTelemetry tracing are configured under `observability`. Access logs use bounded-cardinality route groups by default; raw paths are opt-in. - **HTTP clients** - The shared upstream client honors `HTTP_PROXY` / `HTTPS_PROXY` by default. Set `http_client.trust_environment_proxy: false` when process environment is not trusted. - **Resource bounds** - `lipapi.Call.Validate`, `lipapi.Collect` limits, pending wire event caps, and B2BUA store caps protect memory and request size boundaries. diff --git a/config/config.yaml b/config/config.yaml index 371e4a04..0edd8123 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -304,6 +304,8 @@ plugins: # managed_oauth_storage_path: var/openai_codex_oauth_accounts # managed_oauth_selection_strategy: first-available # first-available | round-robin | session-affinity # managed_oauth_allow_auth_json_fallback: true + # transport: https # default; websocket/auto require experimental_websocket: true + # experimental_websocket: false # gpt55_downgrade_disabled: false # OPENAI_CODEX_ACCESS_TOKEN / OPENAI_CODEX_API_KEY (+ _N variants) env vars - id: ollama diff --git a/config/examples/opencode-codex.yaml b/config/examples/opencode-codex.yaml new file mode 100644 index 00000000..2e6f23c2 --- /dev/null +++ b/config/examples/opencode-codex.yaml @@ -0,0 +1,112 @@ +# OPENCODE + OPENAI-CODEX LIVE TEST +# Routes OpenCode (Responses API) to the ChatGPT Codex backend using the Codex CLI auth file. +# Credentials: auto-discovered from ~/.codex/auth.json (tokens.access_token / refresh_token / account_id). +# Override with OPENAI_CODEX_ACCESS_TOKEN env var or config.access_token if needed. +# +# Start: go run ./cmd/lipstd serve --config ./config/examples/opencode-codex.yaml +# Point OpenCode at: http://127.0.0.1:8080/v1 (Responses API, @ai-sdk/openai), model gpt-5.5 +server: + address: "127.0.0.1:8080" + +routing: + max_attempts: 3 + default_route: "openai-codex:gpt-5.5" + +continuity: + in_memory: true + store: memory + +logging: + level: info + format: text + +diagnostics: + enabled: true + health_path: "/healthz" + attempts_path: "/admin/attempts" + inventory_path: "/debug/inventory" + route_trace_path: "/debug/route_trace" + +hooks: + tool_reactor_error_policy: fail_open + +plugins: + frontends: + - id: openai-responses + enabled: true + config: {} + - id: openai-legacy + enabled: true + config: {} + - id: anthropic + enabled: true + config: {} + - id: gemini + enabled: true + config: {} + backends: + - id: openai-responses + enabled: false + config: {} + - id: openai-legacy + enabled: false + config: {} + - id: anthropic + enabled: false + config: {} + - id: gemini + enabled: false + config: {} + - id: bedrock + enabled: false + config: {} + - id: acp + enabled: false + config: {} + - id: openrouter + enabled: false + config: {} + - id: nvidia + enabled: false + config: {} + - id: opencode-go + enabled: false + config: {} + - id: opencode-zen + enabled: false + config: {} + - id: ollama + enabled: false + config: {} + - id: ollama-cloud + enabled: false + config: {} + - id: llamacpp + enabled: false + config: {} + - id: lmstudio + enabled: false + config: {} + - id: vllm + enabled: false + config: {} + - id: openai-codex + enabled: true + config: + # base_url defaults to https://chatgpt.com/backend-api/codex — leave unset for live ChatGPT. + # access_token left empty so the connector auto-discovers ~/.codex/auth.json. + # account_id is read from auth.json tokens.account_id; override here only if needed. + default_reasoning_effort: "medium" + features: + - id: submit-noop + enabled: true + config: {} + - id: parts-noop + enabled: true + config: {} + - id: tool-reactor-noop + enabled: true + config: {} + - id: codex-client-compat + enabled: true + config: {} diff --git a/docs/openai-codex-backend.md b/docs/openai-codex-backend.md index dc3311c9..e64d92aa 100644 --- a/docs/openai-codex-backend.md +++ b/docs/openai-codex-backend.md @@ -1,6 +1,6 @@ # OpenAI Codex backend -The `openai-codex` backend connects to the ChatGPT Codex Responses API (`https://chatgpt.com/backend-api/codex/responses`). Route selectors use the `openai-codex` prefix, for example `openai-codex:gpt-5.3-codex`. +The `openai-codex` backend connects to the ChatGPT Codex Responses API (`https://chatgpt.com/backend-api/codex/responses`). Route selectors use the `openai-codex` prefix, for example `openai-codex:gpt-5.5`. ## Enable @@ -21,6 +21,12 @@ plugins: - Environment: `OPENAI_CODEX_ACCESS_TOKEN`, then numbered `_2`, `_3`, …; falls back to `OPENAI_CODEX_API_KEY` (+ `_N` variants) when access-token vars are unset. - When neither `access_token` nor `auth_json_path` is set, the connector reads `~/.codex/auth.json` if present (Codex CLI default). +## Token file permissions + +On Unix, the `auth.json` file and managed-OAuth account files in `managed_oauth_storage_path` must be owner-only (`0600`). Files readable or writable by group or other are rejected at load time with an error mentioning `group/other accessible`; fix with `chmod 600 `. This mirrors the Codex CLI `auth.json` guard and fails closed on multi-user hosts. On Windows (ACL-based permissions, no meaningful Unix mode bits) this check is a no-op. + +Symlinked account files inside the managed-OAuth storage directory are skipped during discovery, so a symlink planted in that directory cannot cause the proxy to read a target outside it. Use real files for managed accounts. + ## Optional settings | Field | Purpose | @@ -32,6 +38,9 @@ plugins: | `oauth_client_id` | OAuth client id (OpenAI Codex CLI default) | | `account_id` | `ChatGPT-Account-Id` header | | `default_reasoning_effort` | Default reasoning effort for requests | +| `transport` | `https` (default), `auto`, or `websocket` | +| `experimental_websocket` | Required opt-in for `transport: auto` or `transport: websocket` | +| `websocket_fallback_cooldown_seconds` | Auto-mode cooldown after a pre-output WebSocket failure (default 300) | | `models` | Static model inventory (inline or file), same shape as other backends | | `managed_oauth_enabled` | Load OAuth accounts from JSON files in `managed_oauth_storage_path` | | `managed_oauth_storage_path` | Directory of `*.json` account files | @@ -46,7 +55,11 @@ plugins: | `gpt55_downgrade_target_model` | Target model for free-plan downgrade (default `gpt-5.4`) | | `plan_type_hint` | Optional plan hint for proactive downgrade tests/local overrides | -Without `models`, the connector exposes a built-in Codex model list. +Without `models`, the connector exposes a built-in Codex model list: `gpt-5.5`, `gpt-5.4`, `gpt-5.4-mini`, and `gpt-5.3-codex-spark`. + +## Transport + +The default transport is HTTPS/SSE. WebSocket support is experimental and only enabled when `experimental_websocket: true` is set. With that opt-in, `transport: auto` tries `wss://chatgpt.com/backend-api/codex/responses` first and falls back to HTTPS/SSE only if WebSocket fails before the first canonical event. After that first event, stream errors are surfaced and not retried. Use `transport: websocket` to fail instead of falling back during debugging. After a pre-output WebSocket failure, auto mode skips WebSocket for `websocket_fallback_cooldown_seconds` to avoid repeated retry latency. ## Client compatibility (OpenCode / Pi / Droid / Hermes) @@ -66,7 +79,84 @@ The request-part hook detects client markers from extensions, headers, prompts, ```yaml routes: - default: "openai-codex:gpt-5.3-codex" + default: "openai-codex:gpt-5.5" ``` Bracket parameters such as `?reasoning_effort=high` are supported in route selectors. + +## Per-request routing + +A client can override the configured default route per request by putting a full route +selector in the request body `model` field, with optional URI parameters: + +```json +{ "model": "openai-codex:gpt-5.5?reasoning_effort=low", "input": "ping" } +``` + +The `openai-codex:` prefix selects the backend, the model name selects any model (the +builtin inventory lists `gpt-5.5`, `gpt-5.4`, `gpt-5.4-mini`, and `gpt-5.3-codex-spark`; +arbitrary model strings can still be routed even if not listed), and the `reasoning_effort` +URI parameter is converted into the canonical call options and then into the Codex payload +`reasoning.effort` field. An explicit `X-LIP-Route` header, when present, takes precedence +over the body `model`. A bare model name without a backend prefix still falls back to the +configured default route. + +URI parameters are explicit routing directives and **override** any corresponding value +set elsewhere: a `?reasoning_effort=xhigh` on the selector wins over a `reasoning_effort` +field in the request body and over the backend's `default_reasoning_effort`. A parameter +absent from the selector leaves the other value in effect. The same override rule applies +to `temperature`, `top_p`, `max_output_tokens`, and `parallel_tool_calls` when present in +the selector. + +## Unsupported generation parameters + +The Codex Responses API does not support `temperature`, `top_p`, or `max_output_tokens`. +Plain calls that set any of these fail at payload-build time with an explicit error. +The `openai_codex.ignore_unsupported_gen_params` canonical-call extension (bool, `true`) +opts in to dropping them instead — the `codex-client-compat` feature sets this for +detected compatibility clients (OpenCode, pi, Factory Droid, Hermes) so optional tuning +params are not forwarded upstream and do not fail the request. `reasoning_effort` and +`parallel_tool_calls` are honored. + +## Model name normalization + +Clients that use a `provider/model` namespace (for example OpenCode's `openai/gpt-5.4-mini`) +have the leading `openai/` prefix stripped before the model reaches the Codex upstream, which +rejects org-prefixed model names. A bare model name such as `gpt-5.4-mini` is sent unchanged. + +## System messages + +The Codex Responses API rejects `system`-role items in `input` ("System messages are not +allowed"). System content must be carried in the `instructions` field. The connector folds +system-role messages from the conversation into `instructions` (deduplicated against explicit +instructions, including the `codex-client-compat` bridge) and omits them from `input`, so +clients that send a system prompt (for example OpenCode) interoperate without a capability +mismatch. + +## Tool schemas + +The Codex Responses API requires function-tool parameter schemas to be +"strict-compatible" when sent with `strict:true`: every object must declare +`additionalProperties:false` and list all of its properties in `required`. +Clients that emit looser schemas (for example OpenCode's `apply_patch`, which +omits `additionalProperties`) would otherwise be rejected with +`invalid_function_parameters`. The connector inspects each tool schema and +sends `strict:false` for any schema that is not strict-compatible, while +keeping `strict:true` for strict-compatible and parameterless schemas. This is +a safe relaxation — it only disables strict validation and never causes an +upstream rejection. The Hermes compatibility bridge keeps its existing +`tool_strict:false` behavior (all tools relaxed). + +## Tool call history + +When a client sends a prior assistant tool call and its result back (for example +OpenCode following up after executing a tool), the Chat Completions frontend +encodes the assistant tool call as a `PartJSON` item in the Chat Completions +shape (`type:"function"` with a nested `function:{name,arguments}` object and +`id` as the call id). The connector translates that into a Codex Responses +`function_call` input item (using the Chat Completions `id` as the `call_id`) +and the matching `tool`-role result into a `function_call_output` item with the +same `call_id`, so the upstream sees a correctly linked call/output pair. The +`codex-client-compat` bridge recognizes Chat Completions-style tool calls when +matching tool results, so results that belong to a known call are preserved +rather than treated as orphaned. diff --git a/go.mod b/go.mod index 44b41c66..77dbaaa2 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/bedrock v1.64.0 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.54.0 github.com/aws/smithy-go v1.27.2 + github.com/gorilla/websocket v1.5.3 github.com/jellydator/ttlcache/v3 v3.4.1 github.com/openai/openai-go/v3 v3.41.0 github.com/prometheus/client_golang v1.23.2 @@ -63,7 +64,6 @@ require ( github.com/google/s2a-go v0.1.8 // indirect github.com/google/uuid v1.6.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect - github.com/gorilla/websocket v1.5.3 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect diff --git a/internal/core/diag/debug_summary.go b/internal/core/diag/debug_summary.go new file mode 100644 index 00000000..fce9ce53 --- /dev/null +++ b/internal/core/diag/debug_summary.go @@ -0,0 +1,54 @@ +package diag + +import ( + "log/slog" + "os" + "sort" + "strconv" + "strings" + "sync" +) + +const envDebugTurns = "LIP_CODEX_DEBUG_TURNS" + +var debugTurnsEnabled = sync.OnceValue(func() bool { + return strings.TrimSpace(os.Getenv(envDebugTurns)) != "" +}) + +// DebugTurnsEnabled reports whether verbose per-turn diagnostics are enabled for +// this process. The environment is read once so debug wrappers agree on a single +// process-lifetime gate. +func DebugTurnsEnabled() bool { + return debugTurnsEnabled() +} + +// LoggerOrDefault returns log when present, otherwise slog.Default(). +func LoggerOrDefault(log *slog.Logger) *slog.Logger { + if log != nil { + return log + } + return slog.Default() +} + +// StableCounts formats count maps as sorted "key=value" strings for stable logs. +func StableCounts(counts map[string]int) []string { + keys := make([]string, 0, len(counts)) + for k := range counts { + keys = append(keys, k) + } + sort.Strings(keys) + out := make([]string, 0, len(keys)) + for _, k := range keys { + out = append(out, k+"="+strconv.Itoa(counts[k])) + } + return out +} + +// AppendLimited appends a trimmed non-empty value until max entries are present. +func AppendLimited(values []string, value string, max int) []string { + value = strings.TrimSpace(value) + if value == "" || len(values) >= max { + return values + } + return append(values, value) +} diff --git a/internal/core/routing/parser_test.go b/internal/core/routing/parser_test.go index 4453a293..5a9f5a7d 100644 --- a/internal/core/routing/parser_test.go +++ b/internal/core/routing/parser_test.go @@ -110,7 +110,7 @@ func TestParseFirstSingleArm(t *testing.T) { // Task 14.5: parity with composite routing examples (failover |, weighted ^, [first], [weight=], per-leg query). func TestParseParity_pythonLIPCompositeSelector(t *testing.T) { t.Parallel() - s := "[first]openai-codex:gpt-5.3-codex?reasoning_effort=high^[weight=4]openai-codex:gpt-5.3-codex?reasoning_effort=low|[weight=2]openai-codex:gpt-5.3-codex?reasoning_effort=medium" + s := "[first]openai-codex:gpt-5.3-codex-spark?reasoning_effort=high^[weight=4]openai-codex:gpt-5.3-codex-spark?reasoning_effort=low|[weight=2]openai-codex:gpt-5.3-codex-spark?reasoning_effort=medium" sel, err := Parse(s) if err != nil { t.Fatal(err) @@ -127,7 +127,7 @@ func TestParseParity_pythonLIPCompositeSelector(t *testing.T) { if !b0.IsFirst || b0.Weight != 1 { t.Fatalf("branch0: IsFirst=%v Weight=%d", b0.IsFirst, b0.Weight) } - if b0.Target.Backend != "openai-codex" || b0.Target.Model != "gpt-5.3-codex" { + if b0.Target.Backend != "openai-codex" || b0.Target.Model != "gpt-5.3-codex-spark" { t.Fatalf("branch0 target: %#v", b0.Target) } if b0.Target.Params.Get("reasoning_effort") != "high" { diff --git a/internal/core/routing/routeprefix.go b/internal/core/routing/routeprefix.go new file mode 100644 index 00000000..9c816c0e --- /dev/null +++ b/internal/core/routing/routeprefix.go @@ -0,0 +1,30 @@ +package routing + +import ( + "slices" + "strings" +) + +// FilterRoutePrefixes trims, drops invalid (empty, colon- or slash-bearing), +// dedups, and sorts backend route-selector prefixes. Shared by runtime bundle +// composition and frontend PrefixSet construction so the validation rule lives +// in one place. A prefix is the ":" segment of a route selector; it must +// not itself contain ":" (which would make it a full selector) or "/" (which +// collides with provider-namespace model syntax). +func FilterRoutePrefixes(prefixes []string) []string { + seen := make(map[string]struct{}, len(prefixes)) + out := make([]string, 0, len(prefixes)) + for _, prefix := range prefixes { + prefix = strings.TrimSpace(prefix) + if prefix == "" || strings.Contains(prefix, ":") || strings.Contains(prefix, "/") { + continue + } + if _, dup := seen[prefix]; dup { + continue + } + seen[prefix] = struct{}{} + out = append(out, prefix) + } + slices.Sort(out) + return out +} diff --git a/internal/core/routing/routeprefix_test.go b/internal/core/routing/routeprefix_test.go new file mode 100644 index 00000000..bc651981 --- /dev/null +++ b/internal/core/routing/routeprefix_test.go @@ -0,0 +1,35 @@ +package routing + +import ( + "slices" + "testing" +) + +func TestFilterRoutePrefixes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input []string + want []string + }{ + {name: "empty input", input: nil, want: []string{}}, + {name: "trims whitespace", input: []string{" openai-codex ", "anthropic"}, want: []string{"anthropic", "openai-codex"}}, + {name: "drops empty", input: []string{"", "openai-codex", " "}, want: []string{"openai-codex"}}, + {name: "drops colon-bearing prefix", input: []string{"openai-codex", "foo:bar"}, want: []string{"openai-codex"}}, + {name: "drops slash-bearing prefix", input: []string{"openai-codex", "foo/bar"}, want: []string{"openai-codex"}}, + {name: "dedups", input: []string{"openai-codex", "anthropic", "openai-codex"}, want: []string{"anthropic", "openai-codex"}}, + {name: "sorted output", input: []string{"ollama", "anthropic", "gemini"}, want: []string{"anthropic", "gemini", "ollama"}}, + {name: "all invalid", input: []string{"", ":", "/", " "}, want: []string{}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := FilterRoutePrefixes(tt.input) + if !slices.Equal(got, tt.want) { + t.Errorf("FilterRoutePrefixes(%v) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/core/runtime/executor_open_attempt.go b/internal/core/runtime/executor_open_attempt.go index fc8068a1..92a6f8e3 100644 --- a/internal/core/runtime/executor_open_attempt.go +++ b/internal/core/runtime/executor_open_attempt.go @@ -517,6 +517,7 @@ func (e *Executor) openPlannedCandidate( slog.String("operation", string(openCall.Invocation.Operation)), slog.String("client_delivery_mode", string(openCall.Invocation.DeliveryMode)), slog.String("upstream_transport_mode", string(openCall.Invocation.TransportMode)), + slog.String("reasoning_effort", openCall.Options.ReasoningEffort), slog.Int64("open_duration_ms", time.Since(openStart).Milliseconds()), ) e.logInterleavedRouteSelected(p.ctx, p.traceID, bleg.BLegID, c) diff --git a/internal/core/runtime/executor_test.go b/internal/core/runtime/executor_test.go index 5041afe8..75e1d6b7 100644 --- a/internal/core/runtime/executor_test.go +++ b/internal/core/runtime/executor_test.go @@ -672,7 +672,7 @@ func TestExecutor_routeQueryMergesIntoGenerationOptions(t *testing.T) { } } -func TestExecutor_routeQueryDoesNotOverrideExplicitCallOptions(t *testing.T) { +func TestExecutor_routeQueryOverridesExplicitCallOptions(t *testing.T) { t.Parallel() st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) if err != nil { @@ -716,8 +716,8 @@ func TestExecutor_routeQueryDoesNotOverrideExplicitCallOptions(t *testing.T) { if err != nil { t.Fatal(err) } - if captured.Temperature == nil || *captured.Temperature != 0.11 { - t.Fatalf("explicit call temperature must win over route, got %#v", captured.Temperature) + if captured.Temperature == nil || *captured.Temperature != 0.99 { + t.Fatalf("route temperature must override explicit call option, got %#v", captured.Temperature) } } diff --git a/internal/infra/runtimebundle/build.go b/internal/infra/runtimebundle/build.go index af05fa20..cedf193a 100644 --- a/internal/infra/runtimebundle/build.go +++ b/internal/infra/runtimebundle/build.go @@ -126,7 +126,7 @@ func Build(cfg *config.Config, bus *hooks.Bus, log *slog.Logger, opts *BuildOpti if startedCatalog != nil { closers = append(closers, startedCatalog.closers...) } - backends, inventories, err := buildBackends(cfg, reg, upstream, backendDeps) + backends, inventories, routePrefixes, err := buildBackends(cfg, reg, upstream, backendDeps) if err != nil { if derr := disposeClosers(closers); derr != nil { return nil, errors.Join(err, derr) @@ -286,6 +286,7 @@ func Build(cfg *config.Config, bus *hooks.Bus, log *slog.Logger, opts *BuildOpti Store: store, Closers: closers, UpstreamHTTP: upstream, + RoutePrefixes: routePrefixes, PluginRegistry: reg, EffectiveDefaultRoute: effectiveRoute, Metrics: bundle, @@ -423,9 +424,10 @@ func buildBackends( reg *pluginreg.Registry, upstream *http.Client, backendDeps pluginreg.BackendFactoryDeps, -) (map[string]execbackend.Backend, []modelregistry.BackendInventory, error) { +) (map[string]execbackend.Backend, []modelregistry.BackendInventory, []string, error) { backends := make(map[string]execbackend.Backend, len(cfg.Plugins.Backends)) inventories := make([]modelregistry.BackendInventory, 0, len(cfg.Plugins.Backends)) + rawPrefixes := make([]string, 0, len(cfg.Plugins.Backends)) modelInventoryFetchTimeout := cfg.ModelInventory.FetchTimeoutDuration() for _, p := range cfg.Plugins.Backends { if !p.Enabled { @@ -435,9 +437,10 @@ func buildBackends( iid := p.InstanceID() be, err := reg.BuildBackend(fid, p.Config, upstream, backendDeps) if err != nil { - return nil, nil, fmt.Errorf("backend instance %s (factory %s): %w", iid, fid, err) + return nil, nil, nil, fmt.Errorf("backend instance %s (factory %s): %w", iid, fid, err) } backends[iid] = be + rawPrefixes = append(rawPrefixes, be.BackendPrefixes...) inventories = append(inventories, modelregistry.BackendInventory{ BackendID: iid, Kind: fid, @@ -446,7 +449,8 @@ func buildBackends( FetchTimeout: modelInventoryFetchTimeout, }) } - return backends, inventories, nil + routePrefixes := routing.FilterRoutePrefixes(rawPrefixes) + return backends, inventories, routePrefixes, nil } func resolveRouting(cfg *config.Config, wireModel config.WireModelForBackend) (string, string, *routing.AliasResolver, error) { diff --git a/internal/infra/runtimebundle/build_test.go b/internal/infra/runtimebundle/build_test.go index c6c475c9..4f38baf4 100644 --- a/internal/infra/runtimebundle/build_test.go +++ b/internal/infra/runtimebundle/build_test.go @@ -4,6 +4,7 @@ import ( "context" "io" "net/http" + "slices" "strings" "testing" @@ -176,6 +177,31 @@ func TestBuild_setsEffectiveDefaultRoute_defaultWireModel(t *testing.T) { } } +func TestBuild_derivesRoutePrefixesFromEnabledBackends(t *testing.T) { + t.Parallel() + + reg := pluginreg.NewRegistry() + if err := pluginreg.InstallStandardBackendsOn(reg, pluginreg.UpstreamAPIKeys{}); err != nil { + t.Fatal(err) + } + cfg := &config.Config{ + Routing: config.RoutingConfig{MaxAttempts: 3}, + Plugins: config.PluginsConfig{ + Backends: []config.PluginConfig{{ID: "local-stub", Enabled: true}}, + }, + Continuity: config.ContinuityConfig{InMemory: true}, + } + b, err := runtimebundle.Build(cfg, hooks.New(hooks.Config{}), testkit.DiscardLogger(), &runtimebundle.BuildOptions{ + PluginRegistry: reg, + }) + if err != nil { + t.Fatal(err) + } + if !slices.Contains(b.RoutePrefixes, "local-stub") { + t.Fatalf("RoutePrefixes = %v, want local-stub", b.RoutePrefixes) + } +} + func TestBuild_respectsWireModelInBuildOptions(t *testing.T) { t.Parallel() cfg := &config.Config{ diff --git a/internal/infra/runtimebundle/built.go b/internal/infra/runtimebundle/built.go index 0c363569..8006a8b7 100644 --- a/internal/infra/runtimebundle/built.go +++ b/internal/infra/runtimebundle/built.go @@ -27,6 +27,8 @@ type Built struct { // UpstreamHTTP is the shared outbound HTTP client passed to backends that need upstream HTTP. // Successful [Build] always sets this (explicit [BuildOptions.HTTPClient] or the default from httpclient). UpstreamHTTP *http.Client + // RoutePrefixes are backend route-selector prefixes accepted from frontend protocol model fields. + RoutePrefixes []string // PluginRegistry is the registry used to construct backends and must be used when mounting frontends // or composing features. [Build] sets this from [BuildOptions.PluginRegistry]. PluginRegistry *pluginreg.Registry diff --git a/internal/infra/runtimebundle/opencode_zen_registry_live_test.go b/internal/infra/runtimebundle/opencode_zen_registry_live_test.go index 5ebc04a2..f7dd90f3 100644 --- a/internal/infra/runtimebundle/opencode_zen_registry_live_test.go +++ b/internal/infra/runtimebundle/opencode_zen_registry_live_test.go @@ -3,6 +3,7 @@ package runtimebundle_test import ( + "strings" "testing" "github.com/matdev83/go-llm-interactive-proxy/internal/core/config" @@ -57,26 +58,27 @@ func TestOpenCodeZenLive_modelsAreVendorEnrichedInCentralRegistry(t *testing.T) } defer closeRuntimeBuilt(t, built) - want := map[string]string{ - "deepseek/deepseek-v4-flash-free": "deepseek-v4-flash-free", - "xiaomi/mimo-v2.5-free": "mimo-v2.5-free", - "alibaba/qwen3.6-plus-free": "qwen3.6-plus-free", - "minimax/minimax-m3-free": "minimax-m3-free", + rows := built.ModelRegistry.All() + if len(rows) == 0 { + t.Fatal("central registry has no live opencode-zen models") } - for canonicalID, nativeID := range want { - rows, ok := built.ModelRegistry.Lookup(canonicalID) - if !ok { - t.Fatalf("central registry missing %q", canonicalID) + enriched := 0 + for _, row := range rows { + if row.BackendID != "opencode-zen" || row.Kind != "opencode-zen" { + t.Fatalf("central registry row = %+v, want opencode-zen backend/kind", row) } - found := false - for _, row := range rows { - if row.BackendID == "opencode-zen" && row.Kind == "opencode-zen" && row.NativeID == nativeID { - found = true - break - } + if row.NativeID == "" { + t.Fatalf("central registry row has empty native id: %+v", row) } - if !found { - t.Fatalf("central registry rows for %q = %+v, want opencode-zen native %q", canonicalID, rows, nativeID) + vendor, suffix, ok := strings.Cut(row.CanonicalID, "/") + if !ok || vendor == "" || suffix == "" { + t.Fatalf("central registry row has non-vendor canonical id: %+v", row) } + if vendor != "unknown" { + enriched++ + } + } + if enriched == 0 { + t.Fatalf("central registry has no vendor-enriched opencode-zen rows: %+v", rows) } } diff --git a/internal/pluginreg/backends_install_test.go b/internal/pluginreg/backends_install_test.go index 0a16beeb..c449a8bd 100644 --- a/internal/pluginreg/backends_install_test.go +++ b/internal/pluginreg/backends_install_test.go @@ -10,13 +10,13 @@ import ( func TestPrefixedModelIDsFromYAML_stripsNativePrefixAndFallsBackToCanonicalTail(t *testing.T) { t.Parallel() got, err := prefixedModelIDsFromYAML("openai-codex", modelInventoryYAML{Items: []modelInventoryItemYAML{ - {NativeID: "openai-codex/gpt-5.3-codex"}, + {NativeID: "openai-codex/gpt-5.3-codex-spark"}, {CanonicalID: "openai-codex/gpt-5.4"}, }}) if err != nil { t.Fatal(err) } - want := []prefixedModelYAML{{RawID: "gpt-5.3-codex"}, {RawID: "gpt-5.4"}} + want := []prefixedModelYAML{{RawID: "gpt-5.3-codex-spark"}, {RawID: "gpt-5.4"}} if !reflect.DeepEqual(got, want) { t.Fatalf("models = %#v, want %#v", got, want) } diff --git a/internal/pluginreg/backends_openaicodex.go b/internal/pluginreg/backends_openaicodex.go index 93e99dd1..68d996b8 100644 --- a/internal/pluginreg/backends_openaicodex.go +++ b/internal/pluginreg/backends_openaicodex.go @@ -34,6 +34,9 @@ type openAICodexBackendYAML struct { GPT55DowngradeSourceModel string `yaml:"gpt55_downgrade_source_model"` GPT55DowngradeTargetModel string `yaml:"gpt55_downgrade_target_model"` PlanTypeHint string `yaml:"plan_type_hint"` + Transport string `yaml:"transport"` + ExperimentalWebSocket bool `yaml:"experimental_websocket"` + WebSocketFallbackCooldownSeconds int `yaml:"websocket_fallback_cooldown_seconds"` } func backendOpenAICodex(n yaml.Node, upstream *http.Client, keys UpstreamAPIKeys) (execbackend.Backend, error) { @@ -78,6 +81,11 @@ func backendOpenAICodex(n yaml.Node, upstream *http.Client, keys UpstreamAPIKeys cfg.GPT55DowngradeSourceModel = strings.TrimSpace(y.GPT55DowngradeSourceModel) cfg.GPT55DowngradeTargetModel = strings.TrimSpace(y.GPT55DowngradeTargetModel) cfg.PlanTypeHint = strings.TrimSpace(y.PlanTypeHint) + cfg.Transport = strings.TrimSpace(y.Transport) + cfg.ExperimentalWebSocket = y.ExperimentalWebSocket + if y.WebSocketFallbackCooldownSeconds > 0 { + cfg.WebSocketFallbackCooldown = time.Duration(y.WebSocketFallbackCooldownSeconds) * time.Second + } return applyConfiguredModelInventory(openaicodex.New(cfg), y.Models) } diff --git a/internal/pluginreg/backends_openaicodex_test.go b/internal/pluginreg/backends_openaicodex_test.go index baf5c487..8e029a72 100644 --- a/internal/pluginreg/backends_openaicodex_test.go +++ b/internal/pluginreg/backends_openaicodex_test.go @@ -85,7 +85,7 @@ func TestOpenAICodexBackendFactory_buildsFromYAMLAndHitsRefEmulator(t *testing.T }}, } es, err := be.Open(context.Background(), call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -121,8 +121,8 @@ func TestOpenAICodexBackendFactory_configuredModelsFlowToInventory(t *testing.T) access_token: sk-codex models: items: - - canonical_id: openai-codex/gpt-5.3-codex - native_id: gpt-5.3-codex + - canonical_id: openai-codex/gpt-5.3-codex-spark + native_id: gpt-5.3-codex-spark - canonical_id: openai-codex/gpt-5.4 native_id: gpt-5.4 ` @@ -137,7 +137,7 @@ models: if err != nil { t.Fatal(err) } - if got := nativeIDs(snap.Models); !slices.Equal(got, []string{"gpt-5.3-codex", "gpt-5.4"}) { + if got := nativeIDs(snap.Models); !slices.Equal(got, []string{"gpt-5.3-codex-spark", "gpt-5.4"}) { t.Fatalf("native IDs = %#v", got) } } @@ -170,7 +170,7 @@ func TestOpenAICodexBackendFactory_authJSONPath(t *testing.T) { } _, err = be.Open(context.Background(), lipapi.Call{ Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, - }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}}) + }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}}) if err != nil { t.Fatal(err) } @@ -204,7 +204,7 @@ func TestOpenAICodexBackendFactory_apiKeysFirstKeyUsed(t *testing.T) { } _, err = be.Open(context.Background(), lipapi.Call{ Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, - }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}}) + }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}}) if err != nil { t.Fatal(err) } @@ -234,7 +234,7 @@ func TestOpenAICodexBackendFactory_credentialsAPIKey(t *testing.T) { } _, err = be.Open(context.Background(), lipapi.Call{ Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, - }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}}) + }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}}) if err != nil { t.Fatal(err) } @@ -264,7 +264,7 @@ func TestOpenAICodexBackendFactory_envFallbackWhenYAMLHasNoKeys(t *testing.T) { } _, err = be.Open(context.Background(), lipapi.Call{ Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, - }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}}) + }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}}) if err != nil { t.Fatal(err) } @@ -294,7 +294,7 @@ func TestOpenAICodexBackendFactory_ignoresDefaultTemperatureYAML(t *testing.T) { } es, err := be.Open(context.Background(), lipapi.Call{ Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, - }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}}) + }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}}) if err != nil { t.Fatalf("default_temperature must not be plumbed or rejected: %v", err) } @@ -305,6 +305,63 @@ func TestOpenAICodexBackendFactory_ignoresDefaultTemperatureYAML(t *testing.T) { } } +func TestOpenAICodexBackendFactory_transportHTTPSFromYAML(t *testing.T) { + t.Parallel() + + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + reg := NewRegistry() + if err := InstallStandardBackendsOn(reg, UpstreamAPIKeys{}); err != nil { + t.Fatal(err) + } + var cfg yaml.Node + yamlText := "base_url: " + ts.URL + "/backend-api/codex\naccess_token: sk-codex\ntransport: https\nwebsocket_fallback_cooldown_seconds: 1\n" + if err := yaml.Unmarshal([]byte(yamlText), &cfg); err != nil { + t.Fatal(err) + } + be, err := reg.BuildBackend("openai-codex", cfg, ts.Client(), BackendFactoryDeps{}) + if err != nil { + t.Fatal(err) + } + es, err := be.Open(context.Background(), lipapi.Call{ + Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, + }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}}) + if err != nil { + t.Fatal(err) + } + drainOpenAICodexStream(t, es) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport = %q, want https", got) + } +} + +func TestOpenAICodexBackendFactory_invalidTransportFromYAML(t *testing.T) { + t.Parallel() + + ts := httptest.NewServer(refbackend.New(refbackend.Config{Token: "sk-codex"}).Handler()) + t.Cleanup(ts.Close) + reg := NewRegistry() + if err := InstallStandardBackendsOn(reg, UpstreamAPIKeys{}); err != nil { + t.Fatal(err) + } + var cfg yaml.Node + yamlText := "base_url: " + ts.URL + "/backend-api/codex\naccess_token: sk-codex\ntransport: quic\n" + if err := yaml.Unmarshal([]byte(yamlText), &cfg); err != nil { + t.Fatal(err) + } + be, err := reg.BuildBackend("openai-codex", cfg, ts.Client(), BackendFactoryDeps{}) + if err != nil { + t.Fatal(err) + } + _, err = be.Open(context.Background(), lipapi.Call{ + Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, + }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}}) + if err == nil { + t.Fatal("expected invalid transport config error") + } +} + func drainOpenAICodexStream(t *testing.T, es lipapi.ManagedEventStream) { t.Helper() for { @@ -340,7 +397,7 @@ func TestOpenAICodexBackendFactory_apiKeyAliasForAccessToken(t *testing.T) { } _, err = be.Open(context.Background(), lipapi.Call{ Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, - }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}}) + }, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}}) if err != nil { t.Fatal(err) } diff --git a/internal/pluginreg/frontends_install.go b/internal/pluginreg/frontends_install.go index 007db64e..98d7bb21 100644 --- a/internal/pluginreg/frontends_install.go +++ b/internal/pluginreg/frontends_install.go @@ -7,6 +7,7 @@ import ( frontgemini "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/gemini" frontopenailegacy "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/openailegacy" frontopenairesponses "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/openairesponses" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk" ) @@ -18,6 +19,7 @@ func mountOpenAIResponses(mux *http.ServeMux, opts lipsdk.FrontendMountOptions) mux.Handle("/v1/responses", &frontopenairesponses.Handler{ Exec: opts.Exec, DefaultRouteSelector: opts.DefaultRoute, + RoutePrefixes: routeselect.NewPrefixSet(opts.RoutePrefixes), MaxRequestBodyBytes: opts.MaxRequestBodyBytes, TrafficPorts: opts.TrafficPorts, PreRequestKeepalive: opts.PreRequestKeepalive, @@ -34,6 +36,7 @@ func mountOpenAILegacy(mux *http.ServeMux, opts lipsdk.FrontendMountOptions) err mux.Handle("/v1/chat/completions", &frontopenailegacy.Handler{ Exec: opts.Exec, DefaultRouteSelector: opts.DefaultRoute, + RoutePrefixes: routeselect.NewPrefixSet(opts.RoutePrefixes), MaxRequestBodyBytes: opts.MaxRequestBodyBytes, TrafficPorts: opts.TrafficPorts, PreRequestKeepalive: opts.PreRequestKeepalive, @@ -50,6 +53,7 @@ func mountAnthropic(mux *http.ServeMux, opts lipsdk.FrontendMountOptions) error mux.Handle("/v1/messages", &frontanthropic.Handler{ Exec: opts.Exec, DefaultRouteSelector: opts.DefaultRoute, + RoutePrefixes: routeselect.NewPrefixSet(opts.RoutePrefixes), MaxRequestBodyBytes: opts.MaxRequestBodyBytes, TrafficPorts: opts.TrafficPorts, PreRequestKeepalive: opts.PreRequestKeepalive, @@ -66,6 +70,7 @@ func mountGemini(mux *http.ServeMux, opts lipsdk.FrontendMountOptions) error { h := &frontgemini.Handler{ Exec: opts.Exec, DefaultRouteSelector: opts.DefaultRoute, + RoutePrefixes: routeselect.NewPrefixSet(opts.RoutePrefixes), MaxRequestBodyBytes: opts.MaxRequestBodyBytes, TrafficPorts: opts.TrafficPorts, PreRequestKeepalive: opts.PreRequestKeepalive, diff --git a/internal/plugins/backends/openaicodex/attempt.go b/internal/plugins/backends/openaicodex/attempt.go index 6f253a8e..d3666e77 100644 --- a/internal/plugins/backends/openaicodex/attempt.go +++ b/internal/plugins/backends/openaicodex/attempt.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "strings" + "time" "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/backends/streampeek" @@ -14,12 +15,13 @@ import ( ) type codexOpenEnv struct { - payload Payload - originalModel string - convID string - client *http.Client - endpoint string - downgrade downgradePolicy + payload Payload + originalModel string + convID string + inputFingerprints []string + client *http.Client + endpoint string + downgrade downgradePolicy } func prepareCodexOpenEnv(ctx context.Context, cfg *Config, call lipapi.Call, cand routing.AttemptCandidate, policy downgradePolicy) (*codexOpenEnv, error) { @@ -30,20 +32,23 @@ func prepareCodexOpenEnv(ctx context.Context, cfg *Config, call lipapi.Call, can if err != nil { return nil, err } - originalModel := strings.TrimSpace(cand.Primary.Model) - convID := conversationID(call, originalModel) + logPayloadShape(ctx, &call, payload) + originalModel := normalizeCodexModel(cand.Primary.Model) + inputFingerprints := fingerprintInputItems(payload.Input) + convID := conversationIDForPayloadWithFingerprints(call, originalModel, payload, inputFingerprints) payload.PromptCacheKey = convID client := cfg.HTTPClient if client == nil { client = http.DefaultClient } return &codexOpenEnv{ - payload: payload, - originalModel: originalModel, - convID: convID, - client: client, - endpoint: responsesEndpoint(cfg.BaseURL), - downgrade: policy, + payload: payload, + originalModel: originalModel, + convID: convID, + inputFingerprints: inputFingerprints, + client: client, + endpoint: responsesEndpoint(cfg.BaseURL), + downgrade: policy, }, nil } @@ -109,8 +114,21 @@ func readLimitedClose(resp *http.Response) []byte { return b } +const upstreamErrorBodyMax = 256 + +// truncateErrorMessage bounds upstream/OAuth response text embedded in errors +// so provider error bodies cannot dump multi-KiB of (possibly echoed) content +// into logs. +func truncateErrorMessage(s string, max int) string { + s = strings.TrimSpace(s) + if len(s) <= max { + return s + } + return s[:max] + fmt.Sprintf("…(truncated %d bytes)", len(s)-max) +} + func upstreamHTTPError(status int, body []byte) error { - return fmt.Errorf("%s: upstream HTTP %d: %s", ID, status, strings.TrimSpace(string(body))) + return fmt.Errorf("%s: upstream HTTP %d: %s", ID, status, truncateErrorMessage(string(body), upstreamErrorBodyMax)) } func non2xxOrNil(resp *http.Response) error { @@ -153,9 +171,19 @@ func (a *codexOpenAttempt) openStream(resp *http.Response) (lipapi.ManagedEventS if model == "" { model = a.originalModel } - es := newCodexStream(resp.Body, a.call.MaxPendingWireEvents) - managed := newUsageEstimatingStream(es, a.usageEst, a.call, model) - ev, rerr := managed.Recv(a.ctx) + st := newCodexStream(resp.Body, a.call.MaxPendingWireEvents) + managed, err := openManagedFirstEvent(a.ctx, st, a.usageEst, a.call, model) + if err != nil { + return nil, err + } + return managed, nil +} + +func openManagedFirstEvent(ctx context.Context, es lipapi.ManagedEventStream, usageEst *usageEstimator, call lipapi.Call, model string) (lipapi.ManagedEventStream, error) { + managed := newUsageEstimatingStream(es, usageEst, call, model) + start := time.Now() + ev, rerr := managed.Recv(ctx) + logFirstEventWait(ctx, call, model, start, ev, rerr) if rerr == nil { return streampeek.NewManagedPrependFirst(ev, managed), nil } diff --git a/internal/plugins/backends/openaicodex/attempt_internal_test.go b/internal/plugins/backends/openaicodex/attempt_internal_test.go new file mode 100644 index 00000000..413b33f7 --- /dev/null +++ b/internal/plugins/backends/openaicodex/attempt_internal_test.go @@ -0,0 +1,45 @@ +package openaicodex + +import ( + "strings" + "testing" +) + +func TestUpstreamHTTPError_truncatesLongBody(t *testing.T) { + t.Parallel() + + err := upstreamHTTPError(400, []byte(strings.Repeat("x", 5000))) + if err == nil { + t.Fatal("expected error") + } + msg := err.Error() + if strings.Contains(msg, strings.Repeat("x", 300)) { + t.Fatalf("error leaks long upstream body (len=%d)", len(msg)) + } + if !strings.Contains(msg, "truncated") { + t.Fatalf("expected truncated marker in error: %q", msg) + } + if !strings.Contains(msg, "400") { + t.Fatalf("expected status 400 in error: %q", msg) + } +} + +func TestUpstreamHTTPError_preservesShortBody(t *testing.T) { + t.Parallel() + + body := `{"error":"bad request"}` + err := upstreamHTTPError(418, []byte(body)) + if err == nil { + t.Fatal("expected error") + } + msg := err.Error() + if !strings.Contains(msg, body) { + t.Fatalf("expected short body preserved verbatim: %q", msg) + } + if strings.Contains(msg, "truncated") { + t.Fatalf("short body must not be marked truncated: %q", msg) + } + if !strings.Contains(msg, "418") { + t.Fatalf("expected status 418 in error: %q", msg) + } +} diff --git a/internal/plugins/backends/openaicodex/authjson.go b/internal/plugins/backends/openaicodex/authjson.go index 6bcc5bc2..d007f1a9 100644 --- a/internal/plugins/backends/openaicodex/authjson.go +++ b/internal/plugins/backends/openaicodex/authjson.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "path/filepath" + "runtime" "strings" ) @@ -74,6 +75,9 @@ func loadAuthJSON(path string) (authJSONCredentials, error) { if path == "" { return authJSONCredentials{}, fmt.Errorf("path is empty") } + if err := checkTokenFilePermissions(path); err != nil { + return authJSONCredentials{}, err + } b, err := os.ReadFile(path) if err != nil { return authJSONCredentials{}, err @@ -123,3 +127,23 @@ func firstNonEmpty(values ...string) string { } return "" } + +// checkTokenFilePermissions rejects token files readable or writable by group +// or other on Unix, mirroring the Codex CLI auth.json guard. On Windows +// (ACL-based permissions, no meaningful Unix mode bits) it is a no-op. +func checkTokenFilePermissions(path string) error { + if runtime.GOOS == "windows" { + return nil + } + info, err := os.Stat(path) + if err != nil { + return nil // let the caller's ReadFile produce the canonical not-exist error + } + if !info.Mode().IsRegular() { + return fmt.Errorf("%s: token file %q is not a regular file", ID, path) + } + if info.Mode().Perm()&0o077 != 0 { + return fmt.Errorf("%s: token file %q is group/other accessible (mode %o); expected 0600", ID, path, info.Mode().Perm()) + } + return nil +} diff --git a/internal/plugins/backends/openaicodex/authjson_internal_test.go b/internal/plugins/backends/openaicodex/authjson_internal_test.go new file mode 100644 index 00000000..4538da68 --- /dev/null +++ b/internal/plugins/backends/openaicodex/authjson_internal_test.go @@ -0,0 +1,55 @@ +package openaicodex + +import ( + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +func TestLoadAuthJSON_rejectsGroupReadableTokenFile(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix mode bits not modeled on windows") + } + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "auth.json") + if err := os.WriteFile(path, []byte(`{"access_token":"tok"}`), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Chmod(path, 0o644); err != nil { + t.Fatal(err) + } + _, err := loadAuthJSON(path) + if err == nil { + t.Fatal("expected permission error for group/other-readable token file") + } + if !strings.Contains(err.Error(), "group/other") { + t.Fatalf("expected group/other in error, got: %v", err) + } +} + +func TestLoadAuthJSON_acceptsOwnerOnlyTokenFile(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix mode bits not modeled on windows") + } + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "auth.json") + if err := os.WriteFile(path, []byte(`{"access_token":"tok"}`), 0o600); err != nil { + t.Fatal(err) + } + if err := os.Chmod(path, 0o600); err != nil { + t.Fatal(err) + } + got, err := loadAuthJSON(path) + if err != nil { + t.Fatalf("expected success for owner-only token file, got: %v", err) + } + if got.AccessToken != "tok" { + t.Fatalf("AccessToken = %q want tok", got.AccessToken) + } +} diff --git a/internal/plugins/backends/openaicodex/authjson_test.go b/internal/plugins/backends/openaicodex/authjson_test.go index 3b9960b0..b4ca357d 100644 --- a/internal/plugins/backends/openaicodex/authjson_test.go +++ b/internal/plugins/backends/openaicodex/authjson_test.go @@ -62,7 +62,7 @@ func TestOpen_authJSONPathLoadsToken(t *testing.T) { HTTPClient: ts.Client(), }) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -95,7 +95,7 @@ func TestOpen_explicitAccessTokenOverridesAuthJSON(t *testing.T) { HTTPClient: ts.Client(), }) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -118,7 +118,7 @@ func TestNew_missingAccessTokenInAuthJSON(t *testing.T) { AuthJSONPath: authPath, }) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected config error") @@ -144,7 +144,7 @@ func TestOpen_authJSONCamelCaseFields(t *testing.T) { HTTPClient: ts.Client(), }) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -175,7 +175,7 @@ func TestOpen_defaultAuthJSONPathLoadsToken(t *testing.T) { //nolint:paralleltes HTTPClient: ts.Client(), }) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -207,7 +207,7 @@ func TestOpen_explicitAuthJSONPathOverridesDefaultDiscovery(t *testing.T) { //no HTTPClient: ts.Client(), }) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -225,7 +225,7 @@ func TestNew_missingDefaultAuthJSONKeepsAccessTokenError(t *testing.T) { //nolin BaseURL: "http://127.0.0.1", }) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected config error") @@ -254,7 +254,7 @@ func TestOpen_explicitAccountIDOverridesAuthJSON(t *testing.T) { HTTPClient: ts.Client(), }) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) diff --git a/internal/plugins/backends/openaicodex/config.go b/internal/plugins/backends/openaicodex/config.go index 79dcbf8b..2065783e 100644 --- a/internal/plugins/backends/openaicodex/config.go +++ b/internal/plugins/backends/openaicodex/config.go @@ -1,7 +1,9 @@ package openaicodex import ( + "fmt" "net/http" + "strings" "time" ) @@ -12,6 +14,19 @@ const ( DefaultOAuthClientID = "app_EMoamEEZ73f0CkXaXp7hrann" ) +// Transport mode constants for the Codex backend. +const ( + TransportAuto = "auto" + TransportHTTPS = "https" + TransportWebSocket = "websocket" +) + +// DefaultWebSocketFallbackCooldown is the negative-cache window used when an auto +// transport WebSocket attempt fails before the first canonical event. During the +// cooldown, auto mode skips WebSocket and goes straight to HTTPS to avoid +// repeated dial/handshake latency on known-broken environments. +const DefaultWebSocketFallbackCooldown = 300 * time.Second + type Config struct { BaseURL string AccessToken string @@ -35,4 +50,28 @@ type Config struct { GPT55DowngradeSourceModel string GPT55DowngradeTargetModel string PlanTypeHint string + Transport string + ExperimentalWebSocket bool + WebSocketFallbackCooldown time.Duration +} + +// NormalizeTransport returns the effective transport mode for cfg. An empty +// transport defaults to HTTPS. WebSocket and auto probing are experimental and +// must be enabled explicitly so live clients do not hit the WS path by default. +// An unknown value is rejected with an error so it surfaces through the standard +// config-error path. +func NormalizeTransport(raw string, experimentalWebSocket bool) (string, error) { + t := strings.ToLower(strings.TrimSpace(raw)) + if t == "" { + return TransportHTTPS, nil + } + switch t { + case TransportAuto, TransportHTTPS, TransportWebSocket: + if (t == TransportAuto || t == TransportWebSocket) && !experimentalWebSocket { + return "", fmt.Errorf("%s: transport %q requires experimental_websocket: true", ID, t) + } + return t, nil + default: + return "", fmt.Errorf("%s: unknown transport %q (want %s, %s, or %s)", ID, raw, TransportAuto, TransportHTTPS, TransportWebSocket) + } } diff --git a/internal/plugins/backends/openaicodex/config_test.go b/internal/plugins/backends/openaicodex/config_test.go new file mode 100644 index 00000000..36d493bd --- /dev/null +++ b/internal/plugins/backends/openaicodex/config_test.go @@ -0,0 +1,49 @@ +package openaicodex + +import "testing" + +func TestNormalizeTransport_defaultsAndValid(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in string + experimental bool + want string + }{ + {"empty defaults to https", "", false, TransportHTTPS}, + {"auto allowed with experimental websocket", "auto", true, TransportAuto}, + {"https is case insensitive", "HTTPS", false, TransportHTTPS}, + {"websocket trims whitespace", " websocket ", true, TransportWebSocket}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := NormalizeTransport(tc.in, tc.experimental) + if err != nil { + t.Fatalf("NormalizeTransport(%q) err: %v", tc.in, err) + } + if got != tc.want { + t.Fatalf("NormalizeTransport(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +func TestNormalizeTransport_invalidErrors(t *testing.T) { + t.Parallel() + if _, err := NormalizeTransport("quic", true); err == nil { + t.Fatal("expected error for unknown transport") + } +} + +func TestNormalizeTransport_webSocketRequiresExperimentalOptIn(t *testing.T) { + t.Parallel() + for _, transport := range []string{TransportAuto, TransportWebSocket} { + t.Run(transport, func(t *testing.T) { + t.Parallel() + if _, err := NormalizeTransport(transport, false); err == nil { + t.Fatalf("expected %q to require experimental websocket opt-in", transport) + } + }) + } +} diff --git a/internal/plugins/backends/openaicodex/continuation.go b/internal/plugins/backends/openaicodex/continuation.go new file mode 100644 index 00000000..68bf4ea1 --- /dev/null +++ b/internal/plugins/backends/openaicodex/continuation.go @@ -0,0 +1,333 @@ +package openaicodex + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "strings" + "sync" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +const ( + codexContinuationTTL = time.Hour + codexContinuationMaxEntries = 1024 +) + +type wsContinuationStore struct { + mu sync.Mutex + ttl time.Duration + maxEntries int + entries map[wsContinuationKey]wsContinuationEntry + order []wsContinuationKey + now func() time.Time +} + +type wsContinuationKey struct { + sessionID string + model string + accountID string + promptCacheKey string + clientFamily string +} + +type wsContinuationEntry struct { + responseID string + inputFingerprints []string + outputItemFingerprints []string + instructionsFingerprint string + toolsFingerprint string + inFlight bool + expiresAt time.Time +} + +func newWSContinuationStore(ttl time.Duration, maxEntries int) *wsContinuationStore { + if ttl <= 0 { + ttl = codexContinuationTTL + } + if maxEntries <= 0 { + maxEntries = codexContinuationMaxEntries + } + return &wsContinuationStore{ + ttl: ttl, + maxEntries: maxEntries, + entries: make(map[wsContinuationKey]wsContinuationEntry), + now: time.Now, + } +} + +func (s *wsContinuationStore) prepare(ctx context.Context, cfg *Config, call lipapi.Call, payload *Payload) bool { + if payload == nil { + return false + } + return s.prepareWithFingerprints(ctx, cfg, call, payload, fingerprintInputItems(payload.Input)) +} + +func (s *wsContinuationStore) prepareWithFingerprints(ctx context.Context, cfg *Config, call lipapi.Call, payload *Payload, inputFingerprints []string) bool { + if s == nil || payload == nil { + return false + } + key := continuationKeyWithFingerprints(cfg, call, payload, inputFingerprints) + instructionsFingerprint := fingerprintJSON(payload.Instructions) + toolsFingerprint := fingerprintJSON(payload.Tools) + s.mu.Lock() + defer s.mu.Unlock() + s.purgeExpiredLocked() + entry, ok := s.entries[key] + if !ok { + return false + } + if entry.inFlight { + logWSContinuation(ctx, call, payload.Model, "in_flight", len(payload.Input), len(payload.Input), "") + return false + } + s.touchLocked(key) + if entry.instructionsFingerprint != instructionsFingerprint || + entry.toolsFingerprint != toolsFingerprint { + delete(s.entries, key) + logWSContinuation(ctx, call, payload.Model, "static_fingerprint_changed", 0, len(payload.Input), "") + return false + } + baseline := append([]string(nil), entry.inputFingerprints...) + baseline = append(baseline, entry.outputItemFingerprints...) + sliced, ok := sliceInputForContinuation(baseline, payload.Input, inputFingerprints) + mode := "delta_applied" + if !ok { + sliced, ok = sliceInputAfterReplayedOutputItems(entry.inputFingerprints, len(entry.outputItemFingerprints), payload.Input, inputFingerprints) + mode = "delta_applied_replayed_output" + } + if !ok { + delete(s.entries, key) + logWSContinuation(ctx, call, payload.Model, "input_drift", 0, len(payload.Input), "") + return false + } + before := len(payload.Input) + payload.PreviousResponseID = entry.responseID + payload.Input = sliced + entry.inFlight = true + s.entries[key] = entry + logWSContinuation(ctx, call, payload.Model, mode, before, len(payload.Input), entry.responseID) + return true +} + +func (s *wsContinuationStore) record(cfg *Config, call lipapi.Call, payload Payload, responseID string, outputItems ...inputItem) { + s.recordWithFingerprints(cfg, call, payload, fingerprintInputItems(payload.Input), responseID, outputItems...) +} + +func (s *wsContinuationStore) recordWithFingerprints(cfg *Config, call lipapi.Call, payload Payload, inputFingerprints []string, responseID string, outputItems ...inputItem) { + if s == nil { + return + } + responseID = strings.TrimSpace(responseID) + if responseID == "" { + return + } + key := continuationKeyWithFingerprints(cfg, call, &payload, inputFingerprints) + outputItemFingerprints := fingerprintInputItems(outputItems) + instructionsFingerprint := fingerprintJSON(payload.Instructions) + toolsFingerprint := fingerprintJSON(payload.Tools) + expiresAt := s.now().Add(s.ttl) + entry := wsContinuationEntry{ + responseID: responseID, + inputFingerprints: append([]string(nil), inputFingerprints...), + outputItemFingerprints: outputItemFingerprints, + instructionsFingerprint: instructionsFingerprint, + toolsFingerprint: toolsFingerprint, + expiresAt: expiresAt, + } + s.mu.Lock() + defer s.mu.Unlock() + s.purgeExpiredLocked() + s.entries[key] = entry + s.touchLocked(key) + for len(s.entries) > s.maxEntries && len(s.order) > 0 { + oldest := s.order[0] + s.order = s.order[1:] + delete(s.entries, oldest) + } +} + +func (s *wsContinuationStore) invalidate(cfg *Config, call lipapi.Call, payload *Payload) { + if payload == nil { + return + } + s.invalidateWithFingerprints(cfg, call, payload, fingerprintInputItems(payload.Input)) +} + +func (s *wsContinuationStore) invalidateWithFingerprints(cfg *Config, call lipapi.Call, payload *Payload, inputFingerprints []string) { + if s == nil || payload == nil { + return + } + key := continuationKeyWithFingerprints(cfg, call, payload, inputFingerprints) + s.mu.Lock() + defer s.mu.Unlock() + delete(s.entries, key) + out := s.order[:0] + for _, existing := range s.order { + if existing != key { + out = append(out, existing) + } + } + s.order = out +} + +func (s *wsContinuationStore) purgeExpiredLocked() { + now := s.now() + out := s.order[:0] + for _, key := range s.order { + entry, ok := s.entries[key] + if !ok { + continue + } + if !entry.expiresAt.After(now) { + delete(s.entries, key) + continue + } + out = append(out, key) + } + s.order = out +} + +func (s *wsContinuationStore) touchLocked(key wsContinuationKey) { + out := s.order[:0] + for _, existing := range s.order { + if existing != key { + out = append(out, existing) + } + } + out = append(out, key) + s.order = out +} + +func continuationKeyWithFingerprints(cfg *Config, call lipapi.Call, payload *Payload, inputFingerprints []string) wsContinuationKey { + accountID := "" + if cfg != nil { + accountID = strings.TrimSpace(cfg.AccountID) + } + model := "" + promptCacheKey := "" + if payload != nil { + model = strings.TrimSpace(payload.Model) + promptCacheKey = strings.TrimSpace(payload.PromptCacheKey) + } + sessionID := strings.TrimSpace(call.Session.ContinuityKey) + if sessionID == "" { + sessionID = strings.TrimSpace(call.Session.CorrelationID()) + } + if sessionID == "" && payload != nil && len(payload.Input) > 0 { + sessionID = "input:" + firstInputFingerprint(payload.Input, inputFingerprints) + } + return wsContinuationKey{ + sessionID: sessionID, + model: model, + accountID: accountID, + promptCacheKey: promptCacheKey, + clientFamily: continuationClientFamily(call), + } +} + +func continuationClientFamily(call lipapi.Call) string { + for _, key := range []string{"agent", "openai_codex.agent", "user_agent"} { + if raw, ok := call.Extensions[key]; ok { + var value string + if json.Unmarshal(raw, &value) == nil { + if family := normalizeContinuationFamily(value); family != "generic" { + return family + } + } + } + } + if raw, ok := call.Extensions["headers"]; ok { + var headers map[string]string + if json.Unmarshal(raw, &headers) == nil { + for _, key := range []string{"user-agent", "User-Agent"} { + if family := normalizeContinuationFamily(headers[key]); family != "generic" { + return family + } + } + } + } + return "generic" +} + +func normalizeContinuationFamily(candidate string) string { + lowered := strings.ToLower(strings.TrimSpace(candidate)) + switch { + case strings.Contains(lowered, "opencode"): + return "opencode" + case strings.Contains(lowered, "factory-cli"), strings.Contains(lowered, "factory_cli"), strings.Contains(lowered, "factorydroid"): + return "droid" + default: + return "generic" + } +} + +func sliceInputForContinuation(prior []string, current []inputItem, currentFP []string) ([]inputItem, bool) { + if len(prior) == 0 || len(current) == 0 { + return nil, false + } + if len(currentFP) == 0 { + currentFP = fingerprintInputItems(current) + } + common := 0 + for common < len(prior) && common < len(currentFP) && prior[common] == currentFP[common] { + common++ + } + if common < len(prior) || common <= 0 || common >= len(current) { + return nil, false + } + return append([]inputItem(nil), current[common:]...), true +} + +func sliceInputAfterReplayedOutputItems(prior []string, outputItems int, current []inputItem, currentFP []string) ([]inputItem, bool) { + if len(prior) == 0 || outputItems <= 0 || len(current) == 0 { + return nil, false + } + if len(currentFP) == 0 { + currentFP = fingerprintInputItems(current) + } + if len(currentFP) <= len(prior) { + return nil, false + } + for i := range prior { + if prior[i] != currentFP[i] { + return nil, false + } + } + idx := len(prior) + skipped := 0 + for idx < len(current) && skipped < outputItems { + if _, ok := current[idx].(functionCallItem); !ok { + break + } + idx++ + skipped++ + } + if skipped == 0 || idx >= len(current) { + return nil, false + } + if _, ok := current[idx].(functionCallOutputItem); !ok { + return nil, false + } + return append([]inputItem(nil), current[idx:]...), true +} + +func fingerprintInputItems(items []inputItem) []string { + out := make([]string, 0, len(items)) + for _, item := range items { + out = append(out, fingerprintJSON(item)) + } + return out +} + +func fingerprintJSON(v any) string { + b, err := json.Marshal(v) + if err != nil { + return "" + } + sum := sha256.Sum256(b) + return hex.EncodeToString(sum[:]) +} diff --git a/internal/plugins/backends/openaicodex/continuation_test.go b/internal/plugins/backends/openaicodex/continuation_test.go new file mode 100644 index 00000000..4dcac46e --- /dev/null +++ b/internal/plugins/backends/openaicodex/continuation_test.go @@ -0,0 +1,325 @@ +package openaicodex + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func TestWSContinuationStore_slicesCompatiblePayload(t *testing.T) { + t.Parallel() + store := newWSContinuationStore(time.Minute, 8) + call := lipapi.Call{ + ID: "call-1", + Session: lipapi.SessionRef{ + ClientSessionID: "session-1", + }, + Extensions: map[string]json.RawMessage{ + "agent": json.RawMessage(`"opencode"`), + }, + } + cfg := &Config{AccountID: "acct-1"} + first := Payload{ + Model: "gpt-5.4-mini", + Instructions: "instructions", + PromptCacheKey: "session-1", + Input: []inputItem{ + textMessageItem{Type: "message", Role: "user", Content: "inspect"}, + functionCallItem{Type: "function_call", CallID: "call_1", Name: "bash", Arguments: "{}"}, + functionCallOutputItem{Type: "function_call_output", CallID: "call_1", Output: "ok"}, + }, + Tools: []toolPayload{{Type: "function", Name: "bash"}}, + } + store.record(cfg, call, first, "resp_1") + + next := first + next.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if !store.prepare(context.Background(), cfg, call, &next) { + t.Fatal("expected continuation delta") + } + if next.PreviousResponseID != "resp_1" { + t.Fatalf("previous_response_id = %q", next.PreviousResponseID) + } + if len(next.Input) != 1 { + t.Fatalf("delta input len = %d", len(next.Input)) + } + if len(next.Tools) != 1 { + t.Fatalf("continued payload tools len = %d, want preserved tool schema", len(next.Tools)) + } + msg, ok := next.Input[0].(textMessageItem) + if !ok || msg.Content != "continue" { + t.Fatalf("delta input = %#v", next.Input) + } +} + +func TestWSContinuationStore_slicesAfterPreviousOutputItems(t *testing.T) { + t.Parallel() + store := newWSContinuationStore(time.Minute, 8) + call := lipapi.Call{ + Session: lipapi.SessionRef{ + ClientSessionID: "session-1", + }, + Extensions: map[string]json.RawMessage{ + "agent": json.RawMessage(`"opencode"`), + }, + } + cfg := &Config{AccountID: "acct-1"} + first := Payload{ + Model: "gpt-5.4-mini", + Instructions: "instructions", + PromptCacheKey: "session-1", + Input: []inputItem{ + textMessageItem{Type: "message", Role: "user", Content: "inspect"}, + }, + Tools: []toolPayload{{Type: "function", Name: "bash"}}, + } + assistantCall := functionCallItem{ + Type: "function_call", + ID: "fc_1", + CallID: "call_fc_1", + Name: "bash", + Arguments: `{"cmd":"pwd"}`, + } + store.record(cfg, call, first, "resp_1", assistantCall) + + next := first + next.Input = append(append([]inputItem(nil), first.Input...), + assistantCall, + functionCallOutputItem{Type: "function_call_output", CallID: "call_fc_1", Output: "ok"}, + textMessageItem{Type: "message", Role: "user", Content: "continue"}, + ) + if !store.prepare(context.Background(), cfg, call, &next) { + t.Fatal("expected continuation delta") + } + if len(next.Input) != 2 { + t.Fatalf("delta input len = %d, input=%#v", len(next.Input), next.Input) + } + if _, ok := next.Input[0].(functionCallOutputItem); !ok { + t.Fatalf("first delta item = %#v, want function call output", next.Input[0]) + } + if msg, ok := next.Input[1].(textMessageItem); !ok || msg.Content != "continue" { + t.Fatalf("second delta item = %#v, want continue message", next.Input[1]) + } +} + +func TestWSContinuationStore_slicesAfterReplayedOutputItemsWithDifferentShape(t *testing.T) { + t.Parallel() + store := newWSContinuationStore(time.Minute, 8) + call := lipapi.Call{ + Session: lipapi.SessionRef{ + ClientSessionID: "session-1", + }, + Extensions: map[string]json.RawMessage{ + "agent": json.RawMessage(`"opencode"`), + }, + } + cfg := &Config{AccountID: "acct-1"} + first := Payload{ + Model: "gpt-5.4-mini", + Instructions: "instructions", + PromptCacheKey: "session-1", + Input: []inputItem{ + textMessageItem{Type: "message", Role: "user", Content: "inspect"}, + }, + Tools: []toolPayload{{Type: "function", Name: "bash"}}, + } + store.record(cfg, call, first, "resp_1", functionCallItem{ + Type: "function_call", + ID: "fc_1", + CallID: "call_fc_1", + Name: "bash", + Arguments: `{"cmd":"pwd"}`, + }) + + next := first + next.Input = append(append([]inputItem(nil), first.Input...), + functionCallItem{Type: "function_call", CallID: "call_fc_1", Name: "bash", Arguments: `{"cmd":"pwd"}`}, + functionCallOutputItem{Type: "function_call_output", CallID: "call_fc_1", Output: "ok"}, + textMessageItem{Type: "message", Role: "user", Content: "continue"}, + ) + if !store.prepare(context.Background(), cfg, call, &next) { + t.Fatal("expected continuation delta") + } + if len(next.Input) != 2 { + t.Fatalf("delta input len = %d, input=%#v", len(next.Input), next.Input) + } + if _, ok := next.Input[0].(functionCallOutputItem); !ok { + t.Fatalf("first delta item = %#v, want function call output", next.Input[0]) + } + if msg, ok := next.Input[1].(textMessageItem); !ok || msg.Content != "continue" { + t.Fatalf("second delta item = %#v, want continue message", next.Input[1]) + } +} + +func TestWSContinuationStore_rejectsConcurrentReuseOfPreviousResponseID(t *testing.T) { + t.Parallel() + store := newWSContinuationStore(time.Minute, 8) + call := lipapi.Call{ + Session: lipapi.SessionRef{ + ClientSessionID: "session-1", + }, + Extensions: map[string]json.RawMessage{ + "agent": json.RawMessage(`"opencode"`), + }, + } + cfg := &Config{AccountID: "acct-1"} + first := Payload{ + Model: "gpt-5.4-mini", + Instructions: "instructions", + PromptCacheKey: "session-1", + Input: []inputItem{ + textMessageItem{Type: "message", Role: "user", Content: "inspect"}, + }, + Tools: []toolPayload{{Type: "function", Name: "bash"}}, + } + store.record(cfg, call, first, "resp_1") + + next := first + next.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if !store.prepare(context.Background(), cfg, call, &next) { + t.Fatal("expected first continuation delta") + } + if next.PreviousResponseID != "resp_1" { + t.Fatalf("previous_response_id = %q", next.PreviousResponseID) + } + + duplicate := first + duplicate.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if store.prepare(context.Background(), cfg, call, &duplicate) { + t.Fatal("duplicate in-flight continuation unexpectedly reused previous_response_id") + } + if duplicate.PreviousResponseID != "" { + t.Fatalf("duplicate previous_response_id = %q", duplicate.PreviousResponseID) + } + + child := first + child.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + store.record(cfg, call, child, "resp_2") + + afterChild := child + afterChild.Input = append(append([]inputItem(nil), child.Input...), textMessageItem{Type: "message", Role: "user", Content: "next"}) + if !store.prepare(context.Background(), cfg, call, &afterChild) { + t.Fatal("expected continuation after child response recorded") + } + if afterChild.PreviousResponseID != "resp_2" { + t.Fatalf("previous_response_id after child = %q", afterChild.PreviousResponseID) + } +} + +func TestWSContinuationStore_invalidatePreparedContinuationClearsInFlight(t *testing.T) { + t.Parallel() + store := newWSContinuationStore(time.Minute, 8) + call := lipapi.Call{ + Session: lipapi.SessionRef{ + ClientSessionID: "session-1", + }, + Extensions: map[string]json.RawMessage{ + "agent": json.RawMessage(`"opencode"`), + }, + } + cfg := &Config{AccountID: "acct-1"} + first := Payload{ + Model: "gpt-5.4-mini", + Instructions: "instructions", + PromptCacheKey: "session-1", + Input: []inputItem{ + textMessageItem{Type: "message", Role: "user", Content: "inspect"}, + }, + Tools: []toolPayload{{Type: "function", Name: "bash"}}, + } + store.record(cfg, call, first, "resp_1") + + next := first + next.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if !store.prepare(context.Background(), cfg, call, &next) { + t.Fatal("expected continuation delta") + } + store.invalidate(cfg, call, &first) + + afterFailure := first + afterFailure.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if store.prepare(context.Background(), cfg, call, &afterFailure) { + t.Fatal("invalidated continuation entry must not remain reusable") + } + + store.record(cfg, call, first, "resp_1") + afterRecord := first + afterRecord.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if !store.prepare(context.Background(), cfg, call, &afterRecord) { + t.Fatal("newly recorded continuation should be usable after invalidation") + } +} + +func TestWSContinuationStore_usesAuthoritativeSessionBeforeClientHint(t *testing.T) { + t.Parallel() + store := newWSContinuationStore(time.Minute, 8) + cfg := &Config{} + first := Payload{ + Model: "gpt-5.4-mini", + Instructions: "instructions", + PromptCacheKey: "lip-gpt-5.4-mini-stable", + Input: []inputItem{ + textMessageItem{Type: "message", Role: "user", Content: "initial task"}, + textMessageItem{Type: "message", Role: "assistant", Content: "working"}, + }, + } + store.record(cfg, lipapi.Call{ + Session: lipapi.SessionRef{ + ClientSessionID: "client-session-1", + AuthoritativeSessionID: "proxy-session-1", + }, + }, first, "resp_1") + + next := first + next.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if !store.prepare(context.Background(), cfg, lipapi.Call{ + Session: lipapi.SessionRef{ + ClientSessionID: "client-session-2", + AuthoritativeSessionID: "proxy-session-1", + }, + }, &next) { + t.Fatal("expected continuation despite changed client hint") + } + if next.PreviousResponseID != "resp_1" { + t.Fatalf("previous_response_id = %q", next.PreviousResponseID) + } + + otherSession := first + otherSession.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if store.prepare(context.Background(), cfg, lipapi.Call{ + Session: lipapi.SessionRef{ + ClientSessionID: "client-session-1", + AuthoritativeSessionID: "proxy-session-2", + }, + }, &otherSession) { + t.Fatal("changed authoritative session must not reuse previous_response_id") + } +} + +func TestWSContinuationStore_rejectsStaticFingerprintDrift(t *testing.T) { + t.Parallel() + store := newWSContinuationStore(time.Minute, 8) + call := lipapi.Call{Session: lipapi.SessionRef{ClientSessionID: "session-1"}} + cfg := &Config{} + first := Payload{ + Model: "gpt-5.4-mini", + Instructions: "instructions", + PromptCacheKey: "session-1", + Input: []inputItem{textMessageItem{Type: "message", Role: "user", Content: "inspect"}}, + Tools: []toolPayload{{Type: "function", Name: "bash"}}, + } + store.record(cfg, call, first, "resp_1") + + next := first + next.Tools = []toolPayload{{Type: "function", Name: "grep"}} + next.Input = append(append([]inputItem(nil), first.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + if store.prepare(context.Background(), cfg, call, &next) { + t.Fatal("unexpected continuation delta after tools drift") + } + if next.PreviousResponseID != "" { + t.Fatalf("previous_response_id = %q", next.PreviousResponseID) + } +} diff --git a/internal/plugins/backends/openaicodex/debug.go b/internal/plugins/backends/openaicodex/debug.go new file mode 100644 index 00000000..67749210 --- /dev/null +++ b/internal/plugins/backends/openaicodex/debug.go @@ -0,0 +1,152 @@ +package openaicodex + +import ( + "context" + "encoding/json" + "log/slog" + "strconv" + "strings" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/diag" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func debugTurnsEnabled() bool { + return diag.DebugTurnsEnabled() +} + +func logPayloadShape(ctx context.Context, call *lipapi.Call, payload Payload) { + if !debugTurnsEnabled() || call == nil { + return + } + raw, _ := json.Marshal(payload) + summary := summarizePayload(payload) + slog.DebugContext(ctx, "openaicodex.debug.payload", + "call_id", call.ID, + "trace_id", diag.StableCallID(call), + "a_leg_id", strings.TrimSpace(call.Session.ALegID), + "model", payload.Model, + "payload_bytes", len(raw), + "instructions_bytes", len(payload.Instructions), + "input_text_bytes", summary.inputTextBytes, + "input_items", len(payload.Input), + "input_types", strings.Join(summary.inputTypes, ","), + "function_call_ids", strings.Join(summary.functionCallIDs, ","), + "function_output_ids", strings.Join(summary.functionOutputIDs, ","), + "tools", len(payload.Tools), + "tool_names", strings.Join(summary.toolNames, ","), + "reasoning_effort", reasoningEffort(payload), + "parallel_tool_calls", boolPtrString(payload.ParallelToolCalls), + ) +} + +func logFirstEventWait(ctx context.Context, call lipapi.Call, model string, start time.Time, ev lipapi.Event, err error) { + if !debugTurnsEnabled() { + return + } + attrs := []any{ + "call_id", call.ID, + "trace_id", diag.StableCallID(&call), + "a_leg_id", strings.TrimSpace(call.Session.ALegID), + "model", model, + "duration_ms", time.Since(start).Milliseconds(), + } + if err != nil { + attrs = append(attrs, "status", "error", "error", err.Error()) + } else { + attrs = append(attrs, "status", "ok", "event_kind", string(ev.Kind)) + } + slog.DebugContext(ctx, "openaicodex.debug.first_event", attrs...) +} + +func logWSContinuation(ctx context.Context, call lipapi.Call, model, mode string, inputBefore, inputAfter int, previousResponseID string) { + if !debugTurnsEnabled() { + return + } + slog.DebugContext(ctx, "openaicodex.debug.ws_continuation", + "call_id", call.ID, + "trace_id", diag.StableCallID(&call), + "a_leg_id", strings.TrimSpace(call.Session.ALegID), + "model", model, + "mode", mode, + "input_before", inputBefore, + "input_after", inputAfter, + "previous_response_id", previousResponseID, + ) +} + +type payloadSummary struct { + inputTypes []string + functionCallIDs []string + functionOutputIDs []string + toolNames []string + inputTextBytes int +} + +func summarizePayload(payload Payload) payloadSummary { + typeCounts := map[string]int{} + var functionCallIDs []string + var functionOutputIDs []string + inputTextBytes := 0 + for _, item := range payload.Input { + switch v := item.(type) { + case textMessageItem: + typeCounts[v.Type+":"+v.Role]++ + inputTextBytes += len(v.Content) + case richMessageItem: + typeCounts[v.Type+":"+v.Role]++ + inputTextBytes += richMessageTextBytes(v) + case functionCallItem: + typeCounts[v.Type]++ + inputTextBytes += len(v.Arguments) + functionCallIDs = diag.AppendLimited(functionCallIDs, v.CallID, 12) + case functionCallOutputItem: + typeCounts[v.Type]++ + inputTextBytes += len(v.Output) + functionOutputIDs = diag.AppendLimited(functionOutputIDs, v.CallID, 12) + default: + typeCounts["unknown"]++ + } + } + toolNames := make([]string, 0, min(len(payload.Tools), 12)) + for _, tool := range payload.Tools { + toolNames = diag.AppendLimited(toolNames, tool.Name, 12) + } + return payloadSummary{ + inputTypes: diag.StableCounts(typeCounts), + functionCallIDs: functionCallIDs, + functionOutputIDs: functionOutputIDs, + toolNames: toolNames, + inputTextBytes: inputTextBytes, + } +} + +func richMessageTextBytes(item richMessageItem) int { + total := 0 + for _, block := range item.Content { + switch v := block.(type) { + case inputTextPart: + total += len(v.Text) + case inputImagePart: + total += len(v.ImageURL) + case inputFilePart: + total += len(v.FileData) + len(v.Filename) + } + } + return total +} + +func reasoningEffort(payload Payload) string { + if payload.Reasoning == nil { + return "" + } + return payload.Reasoning.Effort +} + +func boolPtrString(v *bool) string { + if v == nil { + return "" + } + return strconv.FormatBool(*v) +} diff --git a/internal/plugins/backends/openaicodex/downgrade.go b/internal/plugins/backends/openaicodex/downgrade.go index 736f5779..0c998e3b 100644 --- a/internal/plugins/backends/openaicodex/downgrade.go +++ b/internal/plugins/backends/openaicodex/downgrade.go @@ -41,19 +41,33 @@ func (p downgradePolicy) modelForPlan(requested, planType string) string { return p.target } -func (p downgradePolicy) isFreePlanRejection(status int, body string) bool { - if status != http.StatusBadRequest || p.disabled { +func (p downgradePolicy) isReactiveFreePlanRejectionMessage(message string) bool { + if p.disabled { return false } - lower := strings.ToLower(body) - if !strings.Contains(lower, "gpt-5.5") || !strings.Contains(lower, "free") { + lower := strings.ToLower(message) + if !strings.Contains(lower, strings.ToLower(p.source)) || !strings.Contains(lower, "free") { return false } return strings.Contains(lower, "unsupported") || strings.Contains(lower, "not available") } +func (p downgradePolicy) shouldReactiveRetry(originalModel string, alreadyRetried bool, rejectionMessage string) bool { + if alreadyRetried || originalModel != p.source || p.disabled { + return false + } + return p.isReactiveFreePlanRejectionMessage(rejectionMessage) +} + +func (p downgradePolicy) isFreePlanRejection(status int, body string) bool { + if status != http.StatusBadRequest { + return false + } + return p.isReactiveFreePlanRejectionMessage(body) +} + func (p downgradePolicy) retryBody(originalModel string, alreadyRetried bool, status int, body string, payload *Payload) ([]byte, bool, error) { - if alreadyRetried || originalModel != p.source || !p.isFreePlanRejection(status, body) { + if status != http.StatusBadRequest || !p.shouldReactiveRetry(originalModel, alreadyRetried, body) { return nil, false, nil } payload.Model = p.target diff --git a/internal/plugins/backends/openaicodex/downgrade_policy_test.go b/internal/plugins/backends/openaicodex/downgrade_policy_test.go index 720fd2a0..2c21c62a 100644 --- a/internal/plugins/backends/openaicodex/downgrade_policy_test.go +++ b/internal/plugins/backends/openaicodex/downgrade_policy_test.go @@ -27,3 +27,52 @@ func TestDowngradePolicy_disabled(t *testing.T) { t.Fatalf("disabled retryBody = (%v, %v, %v), want (nil, false, nil)", body, ok, err) } } + +func TestDowngradePolicy_isReactiveFreePlanRejectionMessage(t *testing.T) { + t.Parallel() + p := newDowngradePolicy(Config{}) + if !p.isReactiveFreePlanRejectionMessage("gpt-5.5 is not available on free plan") { + t.Fatal("expected free-plan rejection message match") + } + if p.isReactiveFreePlanRejectionMessage("model not found") { + t.Fatal("unrelated message must not match") + } + disabled := newDowngradePolicy(Config{GPT55DowngradeDisabled: true}) + if disabled.isReactiveFreePlanRejectionMessage("gpt-5.5 is not available on free plan") { + t.Fatal("disabled policy must not match") + } +} + +func TestDowngradePolicy_isFreePlanRejection_customSource(t *testing.T) { + t.Parallel() + p := newDowngradePolicy(Config{ + GPT55DowngradeSourceModel: "custom-src", + GPT55DowngradeTargetModel: "custom-dst", + }) + body := `{"error":{"message":"custom-src is not available on free plan"}}` + if !p.isFreePlanRejection(400, body) { + t.Fatal("expected custom source rejection") + } + if p.isFreePlanRejection(400, `{"error":{"message":"gpt-5.5 is not available on free plan"}}`) { + t.Fatal("default source token must not match custom policy") + } +} + +func TestDowngradePolicy_shouldReactiveRetry(t *testing.T) { + t.Parallel() + p := newDowngradePolicy(Config{}) + msg := "gpt-5.5 is not available on free plan" + if !p.shouldReactiveRetry("gpt-5.5", false, msg) { + t.Fatal("expected reactive retry") + } + if p.shouldReactiveRetry("gpt-5.5", true, msg) { + t.Fatal("already retried must not retry") + } + if p.shouldReactiveRetry("gpt-5.3-codex-spark", false, msg) { + t.Fatal("non-source model must not retry") + } + custom := newDowngradePolicy(Config{GPT55DowngradeSourceModel: "custom-src"}) + if !custom.shouldReactiveRetry("custom-src", false, "custom-src is not available on free plan") { + t.Fatal("custom source should retry") + } +} diff --git a/internal/plugins/backends/openaicodex/gpt55_downgrade_test.go b/internal/plugins/backends/openaicodex/gpt55_downgrade_test.go index 8f753bfc..a2dcea0e 100644 --- a/internal/plugins/backends/openaicodex/gpt55_downgrade_test.go +++ b/internal/plugins/backends/openaicodex/gpt55_downgrade_test.go @@ -194,7 +194,7 @@ func TestGPT55Downgrade_nonSourceModelDoesNotDowngrade(t *testing.T) { HTTPClient: srv.Client(), }) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected error") @@ -206,3 +206,88 @@ func TestGPT55Downgrade_nonSourceModelDoesNotDowngrade(t *testing.T) { t.Fatalf("calls = %d, want 1", calls.Load()) } } + +func TestGPT55Downgrade_reactive400CustomSourceRetriesWithTarget(t *testing.T) { + t.Parallel() + var calls atomic.Int32 + var lastModel atomic.Value + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || !strings.HasSuffix(r.URL.Path, "/responses") { + http.NotFound(w, r) + return + } + calls.Add(1) + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "read body", http.StatusBadRequest) + return + } + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + model, _ := payload["model"].(string) + lastModel.Store(model) + if model == "custom-src" { + http.Error(w, `{"error":{"message":"custom-src is not available on free plan"}}`, http.StatusBadRequest) + return + } + refbackend.New(refbackend.Config{Token: "tok"}).Handler().ServeHTTP(w, r) + })) + t.Cleanup(srv.Close) + + be := backend.New(backend.Config{ + BaseURL: srv.URL + "/backend-api/codex", + AccessToken: "tok", + HTTPClient: srv.Client(), + GPT55DowngradeSourceModel: "custom-src", + GPT55DowngradeTargetModel: "custom-dst", + }) + es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ + Primary: routing.Primary{Model: "custom-src"}, + }) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + if calls.Load() != 2 { + t.Fatalf("calls = %d, want 2", calls.Load()) + } + got, ok := lastModel.Load().(string) + if !ok { + t.Fatal("lastModel not string") + } + if got != "custom-dst" { + t.Fatalf("final model = %q, want custom-dst", got) + } +} + +func TestGPT55Downgrade_reactiveWSFreePlanRetriesWithTarget(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{ + Token: "tok", + OutputText: "ok", + ForcedWSRejectModel: "gpt-5.5", + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "tok", + HTTPClient: ts.Client(), + Transport: backend.TransportWebSocket, + ExperimentalWebSocket: true, + }) + es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.5"}, + }) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + if got := requestModel(t, srv); got != "gpt-5.4" { + t.Fatalf("final model = %q, want gpt-5.4", got) + } +} diff --git a/internal/plugins/backends/openaicodex/headers.go b/internal/plugins/backends/openaicodex/headers.go index 93f547f9..08f9d672 100644 --- a/internal/plugins/backends/openaicodex/headers.go +++ b/internal/plugins/backends/openaicodex/headers.go @@ -11,44 +11,86 @@ import ( const ( codexBetaHeader = "responses=experimental" + codexWSBetaHeader = "responses-websocket-mode=v2" codexOriginator = "codex_cli_rs" codexVersionHeader = "0.0.0" codexTaskTypeHeader = "standard" ) -func responsesEndpoint(baseURL string) string { +var codexUserAgentValue = fmt.Sprintf("%s/%s (%s; %s)", codexOriginator, codexVersionHeader, runtime.GOOS, runtime.GOARCH) + +func normalizedResponsesBase(baseURL string) string { base := strings.TrimRight(strings.TrimSpace(baseURL), "/") - if strings.HasSuffix(base, "/responses") { - return base + if !strings.HasSuffix(base, "/responses") { + base += "/responses" } - return base + "/responses" + return base +} + +func responsesEndpoint(baseURL string) string { + return normalizedResponsesBase(baseURL) } func applyCodexHeaders(req *http.Request, cfg Config, conversationID string) { - req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(cfg.AccessToken)) - req.Header.Set("OpenAI-Beta", codexBetaHeader) - req.Header.Set("Accept", "text/event-stream") - req.Header.Set("Content-Type", "application/json") - req.Header.Set("version", codexVersionHeader) - req.Header.Set("originator", codexOriginator) - req.Header.Set("User-Agent", codexUserAgent()) - req.Header.Set("conversation_id", conversationID) - req.Header.Set("session_id", conversationID) - req.Header.Set("Codex-Task-Type", codexTaskTypeHeader) + mergeCodexHeaders(req.Header, cfg, conversationID) +} + +// codexHeaders builds the Codex request headers shared by HTTPS and WebSocket +// transports. WebSocket dial uses this directly since it has no *http.Request. +func codexHeaders(cfg Config, conversationID string) http.Header { + h := http.Header{} + mergeCodexHeaders(h, cfg, conversationID) + return h +} + +func codexWSHeaders(cfg Config, conversationID string) http.Header { + h := codexHeaders(cfg, conversationID) + // The WebSocket handshake uses a different beta opt-in than HTTPS Responses. + // The Python connector sends only responses-websocket-mode=v2 for WS, so this + // intentionally replaces the HTTPS responses=experimental value. + h.Set("OpenAI-Beta", codexWSBetaHeader) + return h +} + +func mergeCodexHeaders(h http.Header, cfg Config, conversationID string) { + h.Set("Authorization", "Bearer "+strings.TrimSpace(cfg.AccessToken)) + h.Set("OpenAI-Beta", codexBetaHeader) + h.Set("Accept", "text/event-stream") + h.Set("Content-Type", "application/json") + h.Set("version", codexVersionHeader) + h.Set("originator", codexOriginator) + h.Set("User-Agent", codexUserAgent()) + h.Set("conversation_id", conversationID) + h.Set("session_id", conversationID) + h.Set("Codex-Task-Type", codexTaskTypeHeader) if id := strings.TrimSpace(cfg.AccountID); id != "" { - req.Header.Set("chatgpt-account-id", id) + h.Set("chatgpt-account-id", id) } } func codexUserAgent() string { - return fmt.Sprintf("%s/%s (%s; %s)", codexOriginator, codexVersionHeader, runtime.GOOS, runtime.GOARCH) + return codexUserAgentValue } -func conversationID(call lipapi.Call, model string) string { +// primaryConversationID returns the first proxy-recognized conversation affinity +// identifier carried on the call: ContinuityKey, then the session correlation id, +// then a non-generated call ID. It returns "" when none apply, so callers can +// fall back to model- or input-derived ids. +func primaryConversationID(call lipapi.Call) string { + if id := strings.TrimSpace(call.Session.ContinuityKey); id != "" { + return id + } if id := strings.TrimSpace(call.Session.CorrelationID()); id != "" { return id } - if id := strings.TrimSpace(call.ID); id != "" { + if id := strings.TrimSpace(call.ID); id != "" && !isGeneratedCallID(id) { + return id + } + return "" +} + +func conversationID(call lipapi.Call, model string) string { + if id := primaryConversationID(call); id != "" { return id } model = strings.TrimSpace(model) @@ -57,3 +99,48 @@ func conversationID(call lipapi.Call, model string) string { } return "lip-" + model } + +func conversationIDForPayload(call lipapi.Call, model string, payload Payload) string { + return conversationIDForPayloadWithFingerprints(call, model, payload, nil) +} + +func conversationIDForPayloadWithFingerprints(call lipapi.Call, model string, payload Payload, inputFingerprints []string) string { + if id := primaryConversationID(call); id != "" { + return id + } + if len(payload.Input) > 0 { + fp := firstInputFingerprint(payload.Input, inputFingerprints) + if len(fp) > 16 { + fp = fp[:16] + } + return "lip-" + strings.TrimSpace(model) + "-" + fp + } + return conversationID(call, model) +} + +func firstInputFingerprint(input []inputItem, inputFingerprints []string) string { + if len(inputFingerprints) > 0 { + return inputFingerprints[0] + } + if len(input) == 0 { + return "" + } + return fingerprintJSON(input[0]) +} + +func isGeneratedCallID(id string) bool { + // Heuristic only: generated canonical call IDs currently look like + // call_. A user-provided ID with the same shape is treated as + // generated so it does not become Codex conversation affinity state; callers + // that need stable affinity should set Session.ContinuityKey or the + // authoritative/correlation session fields instead. + if !strings.HasPrefix(id, "call_") || len(id) <= len("call_") { + return false + } + for _, r := range id[len("call_"):] { + if (r < '0' || r > '9') && (r < 'a' || r > 'f') { + return false + } + } + return true +} diff --git a/internal/plugins/backends/openaicodex/headers_internal_test.go b/internal/plugins/backends/openaicodex/headers_internal_test.go new file mode 100644 index 00000000..c7c32943 --- /dev/null +++ b/internal/plugins/backends/openaicodex/headers_internal_test.go @@ -0,0 +1,152 @@ +package openaicodex + +import ( + "strings" + "testing" + + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +type unmarshalableInputItem struct { + Ch chan int `json:"ch"` +} + +func (unmarshalableInputItem) inputItem() {} + +func TestConversationIDForPayload_shortFingerprintDoesNotPanic(t *testing.T) { + t.Parallel() + got := conversationIDForPayload(lipapi.Call{}, "gpt-test", Payload{ + Input: []inputItem{unmarshalableInputItem{Ch: make(chan int)}}, + }) + if got != "lip-gpt-test-" { + t.Fatalf("conversation id = %q, want empty fingerprint suffix without panic", got) + } +} + +func TestConversationIDForPayload_truncatesFingerprint(t *testing.T) { + t.Parallel() + got := conversationIDForPayload(lipapi.Call{}, "gpt-test", Payload{ + Input: []inputItem{textMessageItem{Type: "message", Role: "user", Content: "hello"}}, + }) + if !strings.HasPrefix(got, "lip-gpt-test-") { + t.Fatalf("conversation id = %q, want lip-gpt-test prefix", got) + } + if suffix := strings.TrimPrefix(got, "lip-gpt-test-"); len(suffix) != 16 { + t.Fatalf("fingerprint suffix length = %d, want 16", len(suffix)) + } +} + +func TestConversationID_precedence(t *testing.T) { + t.Parallel() + const ( + genID = "call_deadbeefdeadbeef" + userID = "user-req-123" + ) + tests := []struct { + name string + call lipapi.Call + model string + want string + }{ + { + name: "continuity key wins", + call: lipapi.Call{ID: userID, Session: lipapi.SessionRef{ContinuityKey: "ck", AuthoritativeSessionID: "auth"}}, + model: "gpt-test", + want: "ck", + }, + { + name: "correlation id wins when no continuity", + call: lipapi.Call{ID: userID, Session: lipapi.SessionRef{AuthoritativeSessionID: "auth"}}, + model: "gpt-test", + want: "auth", + }, + { + name: "non-generated call id wins", + call: lipapi.Call{ID: userID}, + model: "gpt-test", + want: userID, + }, + { + name: "generated call id skipped falls back to model", + call: lipapi.Call{ID: genID}, + model: "gpt-test", + want: "lip-gpt-test", + }, + { + name: "empty model defaults to codex suffix", + call: lipapi.Call{ID: genID}, + model: "", + want: "lip-codex", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := conversationID(tt.call, tt.model) + if got != tt.want { + t.Errorf("conversationID = %q, want %q", got, tt.want) + } + }) + } +} + +func TestConversationIDForPayload_precedence(t *testing.T) { + t.Parallel() + const ( + genID = "call_deadbeefdeadbeef" + userID = "user-req-123" + ) + withInput := Payload{Input: []inputItem{textMessageItem{Type: "message", Role: "user", Content: "hello"}}} + tests := []struct { + name string + call lipapi.Call + model string + payload Payload + want string + }{ + { + name: "continuity key wins over fingerprint", + call: lipapi.Call{ID: genID, Session: lipapi.SessionRef{ContinuityKey: "ck"}}, + model: "gpt-test", + payload: withInput, + want: "ck", + }, + { + name: "correlation id wins over fingerprint", + call: lipapi.Call{ID: genID, Session: lipapi.SessionRef{AuthoritativeSessionID: "auth"}}, + model: "gpt-test", + payload: withInput, + want: "auth", + }, + { + name: "non-generated call id wins over fingerprint", + call: lipapi.Call{ID: userID}, + model: "gpt-test", + payload: withInput, + want: userID, + }, + { + name: "generated call id with no input delegates to conversationID", + call: lipapi.Call{ID: genID}, + model: "gpt-test", + payload: Payload{}, + want: "lip-gpt-test", + }, + { + name: "generated call id no input empty model delegates to conversationID", + call: lipapi.Call{ID: genID}, + model: "", + payload: Payload{}, + want: "lip-codex", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := conversationIDForPayload(tt.call, tt.model, tt.payload) + if got != tt.want { + t.Errorf("conversationIDForPayload = %q, want %q", got, tt.want) + } + }) + } +} diff --git a/internal/plugins/backends/openaicodex/main_test.go b/internal/plugins/backends/openaicodex/main_test.go new file mode 100644 index 00000000..50006983 --- /dev/null +++ b/internal/plugins/backends/openaicodex/main_test.go @@ -0,0 +1,11 @@ +package openaicodex + +import ( + "testing" + + "go.uber.org/goleak" +) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/internal/plugins/backends/openaicodex/managed_oauth_files.go b/internal/plugins/backends/openaicodex/managed_oauth_files.go index 86ef4765..6c8c87a6 100644 --- a/internal/plugins/backends/openaicodex/managed_oauth_files.go +++ b/internal/plugins/backends/openaicodex/managed_oauth_files.go @@ -22,6 +22,12 @@ func loadManagedAccounts(dir string, filter []string) ([]managedAccount, error) if ent.IsDir() || !strings.HasSuffix(strings.ToLower(ent.Name()), ".json") { continue } + if ent.Type()&os.ModeSymlink != 0 { + // security: skip symlinked account files so a planted symlink cannot + // read targets outside the managed-oauth storage directory. + skipped++ + continue + } path := filepath.Join(dir, ent.Name()) acct, ok, err := parseManagedAccountFile(path) if err != nil { @@ -49,6 +55,9 @@ func loadManagedAccounts(dir string, filter []string) ([]managedAccount, error) } func parseManagedAccountFile(path string) (managedAccount, bool, error) { + if err := checkTokenFilePermissions(path); err != nil { + return managedAccount{}, false, err + } b, err := os.ReadFile(path) if err != nil { return managedAccount{}, false, err diff --git a/internal/plugins/backends/openaicodex/managed_oauth_internal_test.go b/internal/plugins/backends/openaicodex/managed_oauth_internal_test.go index 099974ac..1688cb2c 100644 --- a/internal/plugins/backends/openaicodex/managed_oauth_internal_test.go +++ b/internal/plugins/backends/openaicodex/managed_oauth_internal_test.go @@ -422,3 +422,69 @@ func TestPersistQuotaHeaders_updatesCachedPlanType(t *testing.T) { t.Fatalf("usage percent: %q", got["x-codex-primary-used-percent"]) } } + +func TestLoadManagedAccounts_rejectsGroupReadableAccountFile(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("unix mode bits not modeled on windows") + } + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "a.json") + if err := os.WriteFile(path, []byte(`{"account_id":"a","access_token":"tok"}`), 0o644); err != nil { + t.Fatal(err) + } + if err := os.Chmod(path, 0o644); err != nil { + t.Fatal(err) + } + _, err := loadManagedAccounts(dir, nil) + if err == nil { + t.Fatal("expected permission error for group/other-readable account file") + } + if !strings.Contains(err.Error(), "group/other") { + t.Fatalf("expected group/other in error, got: %v", err) + } +} + +func TestLoadManagedAccounts_skipsSymlinkedAccountFile(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("symlink creation needs admin/developer mode on windows") + } + t.Parallel() + + outsideDir := t.TempDir() + outsideFile := filepath.Join(outsideDir, "real.json") + if err := os.WriteFile(outsideFile, []byte(`{"account_id":"leaked","access_token":"secret-leaked"}`), 0o600); err != nil { + t.Fatal(err) + } + if err := os.Chmod(outsideFile, 0o600); err != nil { + t.Fatal(err) + } + + storageDir := t.TempDir() + realFile := filepath.Join(storageDir, "real.json") + if err := os.WriteFile(realFile, []byte(`{"account_id":"ok","access_token":"tok-ok"}`), 0o600); err != nil { + t.Fatal(err) + } + if err := os.Chmod(realFile, 0o600); err != nil { + t.Fatal(err) + } + linkFile := filepath.Join(storageDir, "link.json") + if err := os.Symlink(outsideFile, linkFile); err != nil { + t.Fatal(err) + } + + accounts, err := loadManagedAccounts(storageDir, nil) + if err != nil { + t.Fatalf("loadManagedAccounts: %v", err) + } + if len(accounts) != 1 { + t.Fatalf("expected exactly 1 account (symlink skipped), got %d: %+v", len(accounts), accounts) + } + if accounts[0].ID != "ok" { + t.Fatalf("account ID = %q want ok", accounts[0].ID) + } + if accounts[0].AccessToken == "secret-leaked" || accounts[0].ID == "leaked" { + t.Fatalf("symlink target was followed: %+v", accounts[0]) + } +} diff --git a/internal/plugins/backends/openaicodex/managed_oauth_test.go b/internal/plugins/backends/openaicodex/managed_oauth_test.go index c925db4c..c838054b 100644 --- a/internal/plugins/backends/openaicodex/managed_oauth_test.go +++ b/internal/plugins/backends/openaicodex/managed_oauth_test.go @@ -12,6 +12,7 @@ import ( "testing" "time" + "github.com/gorilla/websocket" "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" backend "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/backends/openaicodex" refbackend "github.com/matdev83/go-llm-interactive-proxy/internal/refbackend/openaicodex" @@ -65,13 +66,16 @@ func TestManagedOAuth_loadsAccountFilesAndUsesTokenAndAccountHeader(t *testing.T cfg.HTTPClient = ts.Client() be := backend.New(cfg) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) } drainEvents(t, es) got := srv.LatestRequest() + if got.Transport != "https" { + t.Fatalf("transport: %q, want https (websocket is experimental opt-in)", got.Transport) + } if got.Authorization != "Bearer tok-one" { t.Fatalf("authorization: %q", got.Authorization) } @@ -80,6 +84,102 @@ func TestManagedOAuth_loadsAccountFilesAndUsesTokenAndAccountHeader(t *testing.T } } +func TestManagedOAuth_websocketModeUsesManagedTokenAndAccountHeader(t *testing.T) { + t.Parallel() + dir := t.TempDir() + writeAccountFile(t, dir, "acct1.json", managedAccountFixture{ + AccountID: "acct-one", + AccessToken: "tok-one", + }) + + srv := refbackend.New(refbackend.Config{Token: "tok-one", OutputText: "managed-ws-ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + cfg := managedOAuthCfg(dir) + cfg.BaseURL = ts.URL + "/backend-api/codex" + cfg.HTTPClient = ts.Client() + cfg.Transport = backend.TransportWebSocket + cfg.ExperimentalWebSocket = true + be := backend.New(cfg) + es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + got := srv.LatestRequest() + if got.Transport != "websocket" { + t.Fatalf("transport: %q, want websocket", got.Transport) + } + if got.Authorization != "Bearer tok-one" { + t.Fatalf("authorization: %q", got.Authorization) + } + if got.ChatGPTAccountID != "acct-one" { + t.Fatalf("account id: %q", got.ChatGPTAccountID) + } +} + +func TestManagedOAuth_httpsModeSkipsWebSocket(t *testing.T) { + t.Parallel() + dir := t.TempDir() + writeAccountFile(t, dir, "acct1.json", managedAccountFixture{ + AccountID: "acct-one", + AccessToken: "tok-one", + }) + + srv := refbackend.New(refbackend.Config{Token: "tok-one", OutputText: "managed-http-ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + cfg := managedOAuthCfg(dir) + cfg.BaseURL = ts.URL + "/backend-api/codex" + cfg.HTTPClient = ts.Client() + cfg.Transport = backend.TransportHTTPS + be := backend.New(cfg) + es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport: %q, want https", got) + } +} + +func TestManagedOAuth_autoFallsBackToHTTPSOnWSFailure(t *testing.T) { + t.Parallel() + dir := t.TempDir() + writeAccountFile(t, dir, "acct1.json", managedAccountFixture{ + AccountID: "acct-one", + AccessToken: "tok-one", + }) + + srv := refbackend.New(refbackend.Config{Token: "tok-one", OutputText: "managed-http-ok", ForcedWSFailure: refbackend.WSFailurePolicyCloseBeforeEvent}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + cfg := managedOAuthCfg(dir) + cfg.BaseURL = ts.URL + "/backend-api/codex" + cfg.HTTPClient = ts.Client() + cfg.Transport = backend.TransportAuto + cfg.ExperimentalWebSocket = true + be := backend.New(cfg) + es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport: %q, want https fallback", got) + } +} + func TestManagedOAuth_roundRobinCyclesTwoAccounts(t *testing.T) { t.Parallel() dir := t.TempDir() @@ -115,7 +215,7 @@ func TestManagedOAuth_roundRobinCyclesTwoAccounts(t *testing.T) { cfg.HTTPClient = srv.Client() cfg.ManagedOAuthSelectionStrategy = "round-robin" be := backend.New(cfg) - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} es1, err := be.Open(context.Background(), sampleCall(), cand) if err != nil { @@ -179,7 +279,7 @@ func TestManagedOAuth_401OnFirstAccountRetriesSecondAndMarksFirstInvalid(t *test cfg.BaseURL = srv.URL + "/backend-api/codex" cfg.HTTPClient = srv.Client() be := backend.New(cfg) - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} es, err := be.Open(context.Background(), sampleCall(), cand) if err != nil { @@ -241,7 +341,7 @@ func TestManagedOAuth_429WithRetryAfterRetriesSecondAndCooldownExcludesFirst(t * cfg.BaseURL = srv.URL + "/backend-api/codex" cfg.HTTPClient = srv.Client() be := backend.New(cfg) - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} es, err := be.Open(context.Background(), sampleCall(), cand) if err != nil { @@ -278,7 +378,7 @@ func TestManagedOAuth_noUsableAccountsAllowFallbackFalseErrors(t *testing.T) { cfg := managedOAuthCfg(dir) be := backend.New(cfg) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected error") @@ -310,7 +410,7 @@ func TestManagedOAuth_all429sStopsAfterAccountBudget(t *testing.T) { cfg.HTTPClient = srv.Client() be := backend.New(cfg) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected error") @@ -342,7 +442,7 @@ func TestManagedOAuth_allowFallbackTrueUsesAuthJSONPath(t *testing.T) { cfg.AuthJSONPath = authPath be := backend.New(cfg) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -395,7 +495,7 @@ func TestManagedOAuth_sessionAffinityReusesAccountAcrossCalls(t *testing.T) { cfg.HTTPClient = srv.Client() cfg.ManagedOAuthSelectionStrategy = "session-affinity" be := backend.New(cfg) - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} es1, err := be.Open(context.Background(), callWithSession("sess-sticky"), cand) if err != nil { @@ -451,7 +551,7 @@ func TestManagedOAuth_sessionAffinityDifferentSessions(t *testing.T) { cfg.HTTPClient = srv.Client() cfg.ManagedOAuthSelectionStrategy = "session-affinity" be := backend.New(cfg) - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} es1, err := be.Open(context.Background(), callWithSession("sess-a"), cand) if err != nil { @@ -513,7 +613,7 @@ func TestManagedOAuth_sessionAffinity401RotatesForSameSession(t *testing.T) { cfg.HTTPClient = srv.Client() cfg.ManagedOAuthSelectionStrategy = "session-affinity" be := backend.New(cfg) - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} call := callWithSession("sess-retry") es, err := be.Open(context.Background(), call, cand) @@ -564,7 +664,7 @@ func TestManagedOAuth_quotaHeadersPersistedOnSuccess(t *testing.T) { cfg.HTTPClient = ts.Client() be := backend.New(cfg) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -600,3 +700,85 @@ func TestManagedOAuth_quotaHeadersPersistedOnSuccess(t *testing.T) { t.Fatal("access_token field lost") } } + +func TestManagedOAuth_websocket401OnFirstAccountRetriesSecondAndMarksFirstInvalid(t *testing.T) { + t.Parallel() + dir := t.TempDir() + writeAccountFile(t, dir, "bad.json", managedAccountFixture{ + AccountID: "acct-bad", + AccessToken: "tok-bad", + }) + writeAccountFile(t, dir, "good.json", managedAccountFixture{ + AccountID: "acct-good", + AccessToken: "tok-good", + }) + + var lastAuth atomic.Value + var wsAttempts atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if websocket.IsWebSocketUpgrade(r) { + wsAttempts.Add(1) + auth := r.Header.Get("Authorization") + lastAuth.Store(auth) + if auth == "Bearer tok-bad" { + http.Error(w, `{"error":"invalid"}`, http.StatusUnauthorized) + return + } + refbackend.New(refbackend.Config{Token: "tok-good"}).Handler().ServeHTTP(w, r) + return + } + if r.Method != http.MethodPost || !strings.HasSuffix(r.URL.Path, "/responses") { + http.NotFound(w, r) + return + } + auth := r.Header.Get("Authorization") + lastAuth.Store(auth) + if auth == "Bearer tok-bad" { + http.Error(w, `{"error":"invalid"}`, http.StatusUnauthorized) + return + } + refbackend.New(refbackend.Config{Token: "tok-good"}).Handler().ServeHTTP(w, r) + })) + t.Cleanup(srv.Close) + + cfg := managedOAuthCfg(dir) + cfg.BaseURL = srv.URL + "/backend-api/codex" + cfg.HTTPClient = srv.Client() + cfg.Transport = backend.TransportWebSocket + cfg.ExperimentalWebSocket = true + be := backend.New(cfg) + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} + + es, err := be.Open(context.Background(), sampleCall(), cand) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + if wsAttempts.Load() < 2 { + t.Fatalf("expected WS retry on first open, wsAttempts=%d", wsAttempts.Load()) + } + got, ok := lastAuth.Load().(string) + if !ok { + t.Fatal("lastAuth not string") + } + if got != "Bearer tok-good" { + t.Fatalf("open auth: %q", got) + } + + before := wsAttempts.Load() + es2, err := be.Open(context.Background(), sampleCall(), cand) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es2) + got, ok = lastAuth.Load().(string) + if !ok { + t.Fatal("lastAuth not string") + } + if got != "Bearer tok-good" { + t.Fatalf("second open auth: %q", got) + } + if wsAttempts.Load()-before != 1 { + t.Fatalf("second open should not retry bad account, wsAttempts=%d", wsAttempts.Load()-before) + } +} diff --git a/internal/plugins/backends/openaicodex/oauth.go b/internal/plugins/backends/openaicodex/oauth.go index 2c8999cb..0ed6a980 100644 --- a/internal/plugins/backends/openaicodex/oauth.go +++ b/internal/plugins/backends/openaicodex/oauth.go @@ -55,7 +55,7 @@ func refreshOAuthAccessToken(ctx context.Context, cfg Config, client *http.Clien defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 8192)) if resp.StatusCode < 200 || resp.StatusCode > 299 { - return cfg, fmt.Errorf("refresh HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + return cfg, fmt.Errorf("refresh HTTP %d: %s", resp.StatusCode, truncateErrorMessage(string(respBody), upstreamErrorBodyMax)) } var parsed map[string]json.RawMessage diff --git a/internal/plugins/backends/openaicodex/oauth_refresh_test.go b/internal/plugins/backends/openaicodex/oauth_refresh_test.go index 181bfd7e..20fb4f4d 100644 --- a/internal/plugins/backends/openaicodex/oauth_refresh_test.go +++ b/internal/plugins/backends/openaicodex/oauth_refresh_test.go @@ -44,7 +44,7 @@ func TestOpen_oauthRefreshRetriesOnceOn401(t *testing.T) { HTTPClient: ts.Client(), }) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -79,7 +79,7 @@ func TestOpen_oauthRefreshFailureReturnsError(t *testing.T) { HTTPClient: ts.Client(), }) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected refresh failure error") @@ -117,7 +117,7 @@ func TestOpen_401WithoutRefreshTokenDoesNotRetry(t *testing.T) { HTTPClient: ts.Client(), }) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected auth error") @@ -126,3 +126,38 @@ func TestOpen_401WithoutRefreshTokenDoesNotRetry(t *testing.T) { t.Fatalf("refresh calls: %d", refreshCalls.Load()) } } + +func TestOpen_oauthRefreshErrorTruncatesLongBody(t *testing.T) { + t.Parallel() + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, strings.Repeat("z", 5000)) + })) + t.Cleanup(tokenSrv.Close) + + srv := refbackend.New(refbackend.Config{Token: "expected-token"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "old-token", + RefreshToken: "bad-refresh", + OAuthTokenURL: tokenSrv.URL, + HTTPClient: ts.Client(), + }) + _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }) + if err == nil { + t.Fatal("expected refresh failure error") + } + msg := err.Error() + if strings.Contains(msg, strings.Repeat("z", 300)) { + t.Fatalf("error leaks long token-endpoint body (len=%d)", len(msg)) + } + if !strings.Contains(msg, "refresh") { + t.Fatalf("expected refresh context in error: %q", msg) + } +} diff --git a/internal/plugins/backends/openaicodex/payload.go b/internal/plugins/backends/openaicodex/payload.go index af075f32..2801c419 100644 --- a/internal/plugins/backends/openaicodex/payload.go +++ b/internal/plugins/backends/openaicodex/payload.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" - "github.com/matdev83/go-llm-interactive-proxy/internal/core/jsonpresence" "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" ) @@ -17,421 +16,129 @@ const defaultCodexInstruction = "You are Codex, based on GPT-5. You are running // emitted with strict=false and parallel_tool_calls defaults to true. const ExtToolStrict = "openai_codex.tool_strict" -type Payload struct { - Model string `json:"model"` - Stream bool `json:"stream"` - Store bool `json:"store"` - Instructions string `json:"instructions"` - Input []inputItem `json:"input"` - Tools []toolPayload `json:"tools,omitempty"` - Reasoning *reasoningSpec `json:"reasoning,omitempty"` - ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` - PromptCacheKey string `json:"prompt_cache_key,omitempty"` -} - -type inputItem interface { - inputItem() -} - -type textMessageItem struct { - Type string `json:"type"` - Role string `json:"role"` - Content string `json:"content"` -} - -func (textMessageItem) inputItem() {} - -type richMessageItem struct { - Type string `json:"type"` - Role string `json:"role"` - Content []contentBlock `json:"content"` -} - -func (richMessageItem) inputItem() {} - -type functionCallOutputItem struct { - Type string `json:"type"` - CallID string `json:"call_id"` - Output string `json:"output"` -} - -func (functionCallOutputItem) inputItem() {} +// ExtIgnoreUnsupportedGenParams is the canonical-call extension key (bool). When +// true, temperature, top_p, and max_output_tokens are dropped instead of failing +// payload build; used by codex-client-compat for OpenCode and similar clients. +const ExtIgnoreUnsupportedGenParams = "openai_codex.ignore_unsupported_gen_params" -type functionCallItem struct { - Type string `json:"type"` - ID string `json:"id,omitempty"` - CallID string `json:"call_id"` - Name string `json:"name"` - Arguments string `json:"arguments"` -} - -func (functionCallItem) inputItem() {} - -type contentBlock interface { - contentBlock() -} - -type inputTextPart struct { - Type string `json:"type"` - Text string `json:"text"` -} - -func (inputTextPart) contentBlock() {} - -type inputImagePart struct { - Type string `json:"type"` - ImageURL string `json:"image_url"` -} - -func (inputImagePart) contentBlock() {} - -type inputFilePart struct { - Type string `json:"type"` - FileData string `json:"file_data"` - Filename string `json:"filename"` +type Payload struct { + Model string `json:"model"` + Stream bool `json:"stream,omitempty"` + Store bool `json:"store"` + Instructions string `json:"instructions"` + Input []inputItem `json:"input"` + Tools []toolPayload `json:"tools,omitempty"` + ToolChoice string `json:"tool_choice,omitempty"` + Reasoning *reasoningSpec `json:"reasoning,omitempty"` + Include []string `json:"include,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + PromptCacheKey string `json:"prompt_cache_key,omitempty"` + PreviousResponseID string `json:"previous_response_id,omitempty"` } -func (inputFilePart) contentBlock() {} - type reasoningSpec struct { - Effort string `json:"effort"` + Effort string `json:"effort"` + Summary string `json:"summary,omitempty"` } -type toolPayload struct { - Type string `json:"type"` - Name string `json:"name"` - Description string `json:"description,omitempty"` - Parameters map[string]any `json:"parameters"` - Strict bool `json:"strict"` +// normalizeCodexModel strips client provider-namespace prefixes (e.g. OpenCode's +// "openai/") that the Codex Responses API rejects. +func normalizeCodexModel(model string) string { + return strings.TrimPrefix(strings.TrimSpace(model), "openai/") } func PayloadForCall(call *lipapi.Call, cand routing.AttemptCandidate, cfg Config) (Payload, error) { if call == nil { return Payload{}, fmt.Errorf("%s: nil call", ID) } - model := strings.TrimSpace(cand.Primary.Model) + model := normalizeCodexModel(cand.Primary.Model) if model == "" { return Payload{}, fmt.Errorf("%s: model is required", ID) } + if err := validateUnsupportedGenParams(call); err != nil { + return Payload{}, err + } items, err := buildInputItems(call) if err != nil { return Payload{}, err } - hermesMode := isHermesToolStrict(call) + toolStrictDisabled := false + if strict, ok := extensionBool(call, ExtToolStrict); ok { + toolStrictDisabled = !strict + } p := Payload{ Model: model, Stream: true, - Instructions: resolveInstructions(call.Instructions), + Instructions: resolveCodexInstructions(call), Input: items, } if len(call.Tools) > 0 { - tools, err := buildTools(call.Tools, hermesMode) + tools, err := buildTools(call.Tools, toolStrictDisabled) if err != nil { return Payload{}, err } p.Tools = tools - if hermesMode && p.ParallelToolCalls == nil { + // Codex accepts tool_choice only when callable tools are present. OpenCode + // sends no-tools turns during compaction/continuation, and forwarding + // tool_choice:auto in that state can make the upstream model behave as if a + // tool protocol still exists. Keep the absence of tools explicit. + p.ToolChoice = "auto" + if toolStrictDisabled && p.ParallelToolCalls == nil { t := true p.ParallelToolCalls = &t } } if effort := strings.TrimSpace(call.Options.ReasoningEffort); effort != "" { - p.Reasoning = &reasoningSpec{Effort: effort} + p.Reasoning = &reasoningSpec{Effort: effort, Summary: "auto"} } else if effort = strings.TrimSpace(cfg.DefaultReasoningEffort); effort != "" { - p.Reasoning = &reasoningSpec{Effort: effort} - } - if call.Options.Temperature != nil { - return Payload{}, fmt.Errorf("%s: temperature is not supported by Codex", ID) - } - if call.Options.MaxOutputTokens != nil && !hasAnthropicModelExtension(call) { - return Payload{}, fmt.Errorf("%s: max output tokens are not supported by Codex", ID) + p.Reasoning = &reasoningSpec{Effort: effort, Summary: "auto"} } - if call.Options.TopP != nil { - return Payload{}, fmt.Errorf("%s: top_p is not supported by Codex", ID) - } - if call.Options.ParallelToolCalls != nil { - p.ParallelToolCalls = call.Options.ParallelToolCalls - } - return p, nil -} - -func resolveInstructions(insts []lipapi.Message) string { - if text := joinInstructionText(insts); text != "" { - return text + if p.Reasoning != nil { + p.Include = []string{"reasoning.encrypted_content"} } - return defaultCodexInstruction -} - -func joinInstructionText(insts []lipapi.Message) string { - var b strings.Builder - for _, m := range insts { - for _, p := range m.Parts { - if p.Kind != lipapi.PartText { - continue - } - if strings.TrimSpace(p.Text) == "" { - continue - } - if b.Len() > 0 { - b.WriteString("\n\n") - } - b.WriteString(p.Text) + if len(call.Tools) > 0 { + if call.Options.ParallelToolCalls != nil { + p.ParallelToolCalls = call.Options.ParallelToolCalls + } else if p.ParallelToolCalls == nil { + v := false + p.ParallelToolCalls = &v } } - return strings.TrimSpace(b.String()) -} - -func hasAnthropicModelExtension(call *lipapi.Call) bool { - if call == nil || call.Extensions == nil { - return false - } - _, ok := call.Extensions["anthropic.model"] - return ok + return p, nil } -func isHermesToolStrict(call *lipapi.Call) bool { +func extensionBool(call *lipapi.Call, key string) (bool, bool) { if call == nil || call.Extensions == nil { - return false + return false, false } - raw, ok := call.Extensions[ExtToolStrict] + raw, ok := call.Extensions[key] if !ok { - return false + return false, false } var b bool if err := json.Unmarshal(raw, &b); err != nil { - return false + return false, false } - return !b + return b, true } -func buildInputItems(call *lipapi.Call) ([]inputItem, error) { - out := make([]inputItem, 0, len(call.Messages)) - for _, m := range call.Messages { - if m.Role == lipapi.RoleTool { - for _, p := range m.Parts { - if p.Kind != lipapi.PartToolResult { - return nil, fmt.Errorf("%s: unsupported tool part kind %q", ID, p.Kind) - } - out = append(out, functionCallOutputItem{ - Type: "function_call_output", - CallID: p.ToolCallID, - Output: toolResultString(p), - }) - } - continue - } - if m.Role == lipapi.RoleAssistant && len(m.Parts) > 0 { - fcs, ok, err := assistantFunctionCallItems(m.Parts) - if err != nil { - return nil, err - } - if ok { - out = append(out, fcs...) - continue - } - } - item, err := messageToInputItem(m) - if err != nil { - return nil, err - } - out = append(out, item) +func validateUnsupportedGenParams(call *lipapi.Call) error { + if ignore, ok := extensionBool(call, ExtIgnoreUnsupportedGenParams); ok && ignore { + return nil } - return out, nil -} - -func assistantFunctionCallItems(parts []lipapi.Part) ([]inputItem, bool, error) { - out := make([]inputItem, 0, len(parts)) - for _, p := range parts { - item, ok, err := partToFunctionCallItem(p) - if err != nil { - return nil, false, err - } - if !ok { - return nil, false, nil - } - out = append(out, item) - } - if len(out) == 0 { - return nil, false, nil - } - return out, true, nil -} - -func partToFunctionCallItem(p lipapi.Part) (inputItem, bool, error) { - if p.Kind != lipapi.PartJSON || len(p.Content) == 0 { - return nil, false, nil - } - var probe struct { - Type string `json:"type"` - } - if err := json.Unmarshal(p.Content, &probe); err != nil { - return nil, false, nil - } - if t := strings.TrimSpace(probe.Type); t != "" && t != "function_call" { - return nil, false, nil - } - var v struct { - ID string `json:"id"` - CallID string `json:"call_id"` - Name string `json:"name"` - Arguments json.RawMessage `json:"arguments"` - } - if err := json.Unmarshal(p.Content, &v); err != nil { - return nil, false, fmt.Errorf("%s: function_call json: %w", ID, err) - } - callID := strings.TrimSpace(v.CallID) - name := strings.TrimSpace(v.Name) - if callID == "" || name == "" { - return nil, false, fmt.Errorf("%s: function_call requires call_id and name", ID) - } - argStr := "{}" - if jsonpresence.IsPresentNonNullJSON(v.Arguments) { - switch v.Arguments[0] { - case '"': - var s string - if err := json.Unmarshal(v.Arguments, &s); err != nil { - return nil, false, fmt.Errorf("%s: function_call arguments: %w", ID, err) - } - argStr = s - default: - if !json.Valid(v.Arguments) { - return nil, false, fmt.Errorf("%s: function_call arguments must be JSON", ID) - } - argStr = string(v.Arguments) - } - } - item := functionCallItem{ - Type: "function_call", - CallID: callID, - Name: name, - Arguments: argStr, - } - if id := strings.TrimSpace(v.ID); id != "" { - item.ID = id - } - return item, true, nil -} - -func toolResultString(p lipapi.Part) string { - if len(p.Content) == 0 { - return "" - } - return string(p.Content) -} - -func messageToInputItem(m lipapi.Message) (inputItem, error) { - role := roleString(m.Role) - if len(m.Parts) == 1 && m.Parts[0].Kind == lipapi.PartText { - return textMessageItem{ - Type: "message", - Role: role, - Content: m.Parts[0].Text, - }, nil - } - content, err := partsToContentList(m.Parts) - if err != nil { - return nil, err - } - return richMessageItem{ - Type: "message", - Role: role, - Content: content, - }, nil -} - -func roleString(r lipapi.Role) string { - switch r { - case lipapi.RoleUser: - return "user" - case lipapi.RoleAssistant: - return "assistant" - case lipapi.RoleSystem: - return "system" - default: - return "user" - } -} - -func partsToContentList(parts []lipapi.Part) ([]contentBlock, error) { - out := make([]contentBlock, 0, len(parts)) - for _, p := range parts { - switch p.Kind { - case lipapi.PartText: - if strings.TrimSpace(p.Text) == "" { - continue - } - out = append(out, inputTextPart{Type: "input_text", Text: p.Text}) - case lipapi.PartImageRef: - out = append(out, inputImagePart{ - Type: "input_image", - ImageURL: p.ImageRef, - }) - case lipapi.PartFileRef: - b64, fname, err := fileDataFromPart(p) - if err != nil { - return nil, err - } - out = append(out, inputFilePart{ - Type: "input_file", - FileData: b64, - Filename: fname, - }) - default: - return nil, fmt.Errorf("%s: unsupported part kind %q", ID, p.Kind) - } - } - return out, nil -} - -func fileDataFromPart(p lipapi.Part) (dataB64, filename string, err error) { - filename = strings.TrimSpace(p.FileName) - ref := p.FileRef - if strings.HasPrefix(ref, "data:") { - _, b64, ok := stripDataURLBase64(ref) - if !ok { - return "", "", fmt.Errorf("%s: invalid data URL in file part", ID) - } - return b64, filename, nil - } - return "", "", fmt.Errorf("%s: file part requires a data URL", ID) -} - -func stripDataURLBase64(dataURL string) (mime, b64 string, ok bool) { - rest, ok := strings.CutPrefix(dataURL, "data:") - if !ok { - return "", "", false + var unsupported []string + if call.Options.Temperature != nil { + unsupported = append(unsupported, "temperature") } - mime, enc, found := strings.Cut(rest, ";") - if !found { - return "", "", false + if call.Options.TopP != nil { + unsupported = append(unsupported, "top_p") } - const prefix = "base64," - encBody, ok := strings.CutPrefix(enc, prefix) - if !ok { - return "", "", false + if call.Options.MaxOutputTokens != nil { + unsupported = append(unsupported, "max_output_tokens") } - return mime, encBody, true -} - -func buildTools(tools []lipapi.ToolDef, strictFalse bool) ([]toolPayload, error) { - out := make([]toolPayload, 0, len(tools)) - for _, t := range tools { - var schema map[string]any - if len(t.Parameters) > 0 { - if err := json.Unmarshal(t.Parameters, &schema); err != nil { - return nil, fmt.Errorf("%s: tool %q parameters: %w", ID, t.Name, err) - } - } - if schema == nil { - schema = map[string]any{} - } - out = append(out, toolPayload{ - Type: "function", - Name: t.Name, - Description: t.Description, - Parameters: schema, - Strict: !strictFalse, - }) + if len(unsupported) == 0 { + return nil } - return out, nil + return fmt.Errorf("%s: unsupported generation parameter(s) %s (Codex Responses API); set extension %q to ignore", + ID, strings.Join(unsupported, ", "), ExtIgnoreUnsupportedGenParams) } diff --git a/internal/plugins/backends/openaicodex/payload_input.go b/internal/plugins/backends/openaicodex/payload_input.go new file mode 100644 index 00000000..94065533 --- /dev/null +++ b/internal/plugins/backends/openaicodex/payload_input.go @@ -0,0 +1,496 @@ +package openaicodex + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/jsonpresence" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +type inputItem interface { + inputItem() +} + +type textMessageItem struct { + Type string `json:"type"` + Role string `json:"role"` + Content string `json:"content"` +} + +func (textMessageItem) inputItem() {} + +type richMessageItem struct { + Type string `json:"type"` + Role string `json:"role"` + Content []contentBlock `json:"content"` +} + +func (richMessageItem) inputItem() {} + +type functionCallOutputItem struct { + Type string `json:"type"` + CallID string `json:"call_id"` + Output string `json:"output"` +} + +func (functionCallOutputItem) inputItem() {} + +type functionCallItem struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + CallID string `json:"call_id"` + Name string `json:"name"` + Arguments string `json:"arguments"` +} + +func (functionCallItem) inputItem() {} + +type contentBlock interface { + contentBlock() +} + +type inputTextPart struct { + Type string `json:"type"` + Text string `json:"text"` +} + +func (inputTextPart) contentBlock() {} + +type inputImagePart struct { + Type string `json:"type"` + ImageURL string `json:"image_url"` +} + +func (inputImagePart) contentBlock() {} + +type inputFilePart struct { + Type string `json:"type"` + FileData string `json:"file_data"` + Filename string `json:"filename"` +} + +func (inputFilePart) contentBlock() {} + +// resolveCodexInstructions builds the Codex `instructions` field. The Codex Responses API +// rejects system-role items in `input`, so system messages from the conversation are folded +// into instructions (deduplicated against explicit instructions, e.g. the codex-client-compat +// bridge). Falls back to the default Codex instruction only when no system content is present. +func resolveCodexInstructions(call *lipapi.Call) string { + instructions := joinInstructionText(call.Instructions) + for _, sysText := range systemMessageTexts(call.Messages) { + if sysText == "" || instructionHasBlock(instructions, sysText) { + continue + } + if instructions != "" { + instructions += "\n\n" + sysText + } else { + instructions = sysText + } + } + if strings.TrimSpace(instructions) == "" { + return defaultCodexInstruction + } + return instructions +} + +// instructionHasBlock reports whether instructions already contains block as a complete +// \n\n-delimited block. Substring containment is intentionally NOT used: a short system +// message that happens to be a substring of a longer instruction block (e.g. "Be concise." +// within "Be concise and helpful.") must still be merged rather than silently dropped. Only +// an exact full-block match (e.g. the codex-client-compat bridge re-sent verbatim) is a dup. +func instructionHasBlock(instructions, block string) bool { + if block == "" { + return false + } + for part := range strings.SplitSeq(instructions, "\n\n") { + if part == block { + return true + } + } + return false +} + +func systemMessageTexts(msgs []lipapi.Message) []string { + var out []string + for _, m := range msgs { + if m.Role != lipapi.RoleSystem { + continue + } + var b strings.Builder + for _, p := range m.Parts { + if p.Kind != lipapi.PartText || strings.TrimSpace(p.Text) == "" { + continue + } + if b.Len() > 0 { + b.WriteByte('\n') + } + b.WriteString(p.Text) + } + if b.Len() > 0 { + out = append(out, b.String()) + } + } + return out +} + +func joinInstructionText(insts []lipapi.Message) string { + var b strings.Builder + for _, m := range insts { + for _, p := range m.Parts { + if p.Kind != lipapi.PartText { + continue + } + if strings.TrimSpace(p.Text) == "" { + continue + } + if b.Len() > 0 { + b.WriteString("\n\n") + } + b.WriteString(p.Text) + } + } + return strings.TrimSpace(b.String()) +} + +func buildInputItems(call *lipapi.Call) ([]inputItem, error) { + out := make([]inputItem, 0, len(call.Messages)) + hasTools := len(call.Tools) > 0 + for _, m := range call.Messages { + if m.Role == lipapi.RoleSystem { + continue + } + if m.Role == lipapi.RoleTool { + for _, p := range m.Parts { + if p.Kind != lipapi.PartToolResult { + return nil, fmt.Errorf("%s: unsupported tool part kind %q", ID, p.Kind) + } + if !hasTools { + // No-tools requests are a real client state, not malformed history: + // OpenCode can resend prior tool transcripts while asking the model + // to continue without exposing callable tools. Codex must see those + // records as plain conversation text, because a function_call_output + // without matching tools can stall the turn or make the model emit raw + // tool protocol back to the user. + out = append(out, textMessageItem{ + Type: "message", + Role: "user", + Content: noToolsToolResultText(p), + }) + continue + } + out = append(out, functionCallOutputItem{ + Type: "function_call_output", + CallID: p.ToolCallID, + Output: toolResultString(p), + }) + } + continue + } + if m.Role == lipapi.RoleAssistant && len(m.Parts) > 0 { + items, ok, err := assistantFunctionCallItems(m.Parts, hasTools) + if err != nil { + return nil, err + } + if ok { + out = append(out, items...) + continue + } + } + item, err := messageToInputItem(m) + if err != nil { + return nil, err + } + out = append(out, item) + } + return out, nil +} + +func assistantFunctionCallItems(parts []lipapi.Part, hasTools bool) ([]inputItem, bool, error) { + out := make([]inputItem, 0, len(parts)) + contentParts := make([]lipapi.Part, 0, len(parts)) + sawFunctionCall := false + flushContent := func() error { + if len(contentParts) == 0 { + return nil + } + item, err := messageToInputItem(lipapi.Message{Role: lipapi.RoleAssistant, Parts: contentParts}) + if err != nil { + return err + } + out = append(out, item) + contentParts = contentParts[:0] + return nil + } + for _, p := range parts { + item, ok, err := partToFunctionCallItem(p) + if err != nil { + return nil, false, err + } + if !ok { + contentParts = append(contentParts, p) + continue + } + if !hasTools { + sawFunctionCall = true + // Preserve the fact that a prior assistant action happened, but do not + // send Codex a structured function_call when this request has no tool + // schema. The structured form is reserved for tool-enabled turns where + // the backend can safely continue the protocol. + contentParts = append(contentParts, lipapi.TextPart(noToolsFunctionCallText(item))) + continue + } + if err := flushContent(); err != nil { + return nil, false, err + } + sawFunctionCall = true + out = append(out, item) + } + if !sawFunctionCall { + return nil, false, nil + } + if err := flushContent(); err != nil { + return nil, false, err + } + return out, true, nil +} + +func noToolsFunctionCallText(item inputItem) string { + fc, ok := item.(functionCallItem) + if !ok { + return "Prior assistant tool call (tools unavailable in this request)." + } + var b strings.Builder + b.WriteString("Prior assistant tool call (tools unavailable in this request).") + if strings.TrimSpace(fc.CallID) != "" { + b.WriteString(" call_id=") + b.WriteString(strings.TrimSpace(fc.CallID)) + b.WriteByte('.') + } + if strings.TrimSpace(fc.Name) != "" { + b.WriteString(" name=") + b.WriteString(strings.TrimSpace(fc.Name)) + b.WriteByte('.') + } + if strings.TrimSpace(fc.Arguments) != "" { + b.WriteString(" arguments=") + b.WriteString(strings.TrimSpace(fc.Arguments)) + } + return b.String() +} + +func noToolsToolResultText(p lipapi.Part) string { + var b strings.Builder + b.WriteString("Prior tool output (tools unavailable in this request).") + if id := strings.TrimSpace(p.ToolCallID); id != "" { + b.WriteString(" call_id=") + b.WriteString(id) + b.WriteByte('.') + } + if len(p.Content) > 0 { + b.WriteByte('\n') + b.WriteString(toolResultDisplayText(p.Content)) + } + return b.String() +} + +func toolResultDisplayText(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + if raw[0] == '"' { + var s string + if json.Unmarshal(raw, &s) == nil { + return s + } + } + return string(raw) +} + +func partToFunctionCallItem(p lipapi.Part) (inputItem, bool, error) { + if p.Kind != lipapi.PartJSON || len(p.Content) == 0 { + return nil, false, nil + } + var v struct { + Type string `json:"type"` + ID string `json:"id"` + CallID string `json:"call_id"` + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + Function struct { + Name string `json:"name"` + Arguments json.RawMessage `json:"arguments"` + } `json:"function"` + } + if err := json.Unmarshal(p.Content, &v); err != nil { + return nil, false, nil + } + t := strings.TrimSpace(v.Type) + // Accept Responses-style ("function_call" or empty) and Chat Completions-style + // ("function") assistant tool calls. Any other concrete type is not a function call. + if t != "" && t != "function_call" && t != "function" { + return nil, false, nil + } + // Chat Completions carries the call id as "id" and the name/arguments under + // "function"; Responses carries them at the top level as "call_id"/"name". + callID := strings.TrimSpace(v.CallID) + if callID == "" { + callID = strings.TrimSpace(v.ID) + } + name := strings.TrimSpace(v.Name) + args := v.Arguments + if name == "" { + name = strings.TrimSpace(v.Function.Name) + args = v.Function.Arguments + } + if callID == "" || name == "" { + return nil, false, fmt.Errorf("%s: function_call requires call_id and name", ID) + } + argStr := "{}" + if jsonpresence.IsPresentNonNullJSON(args) { + switch args[0] { + case '"': + var s string + if err := json.Unmarshal(args, &s); err != nil { + return nil, false, fmt.Errorf("%s: function_call arguments: %w", ID, err) + } + argStr = s + default: + argStr = string(args) + } + } + item := functionCallItem{ + Type: "function_call", + CallID: callID, + Name: name, + Arguments: argStr, + } + // Preserve the Responses-style item id only when it is distinct from the call_id + // (i.e. a separate call_id was supplied). For Chat Completions the id IS the call + // id, so do not duplicate it as the item id. + if strings.TrimSpace(v.ID) != "" && strings.TrimSpace(v.CallID) != "" { + item.ID = strings.TrimSpace(v.ID) + } + return item, true, nil +} + +func toolResultString(p lipapi.Part) string { + output := "" + if len(p.Content) > 0 { + output = string(p.Content) + } + outputRaw, err := json.Marshal(output) + if err != nil { + return output + } + payload := map[string]json.RawMessage{"output": outputRaw} + var existing map[string]json.RawMessage + if json.Unmarshal(p.Content, &existing) == nil { + if raw, ok := existing["exit_code"]; ok { + payload["exit_code"] = raw + } + if raw, ok := existing["workdir"]; ok { + payload["workdir"] = raw + } + } + raw, err := json.Marshal(payload) + if err != nil { + return output + } + return string(raw) +} + +func messageToInputItem(m lipapi.Message) (inputItem, error) { + role := roleString(m.Role) + if len(m.Parts) == 1 && m.Parts[0].Kind == lipapi.PartText { + return textMessageItem{ + Type: "message", + Role: role, + Content: m.Parts[0].Text, + }, nil + } + content, err := partsToContentList(m.Parts) + if err != nil { + return nil, err + } + return richMessageItem{ + Type: "message", + Role: role, + Content: content, + }, nil +} + +func roleString(r lipapi.Role) string { + switch r { + case lipapi.RoleUser: + return "user" + case lipapi.RoleAssistant: + return "assistant" + case lipapi.RoleSystem: + return "system" + default: + return "user" + } +} + +func partsToContentList(parts []lipapi.Part) ([]contentBlock, error) { + out := make([]contentBlock, 0, len(parts)) + for _, p := range parts { + switch p.Kind { + case lipapi.PartText: + if strings.TrimSpace(p.Text) == "" { + continue + } + out = append(out, inputTextPart{Type: "input_text", Text: p.Text}) + case lipapi.PartImageRef: + out = append(out, inputImagePart{ + Type: "input_image", + ImageURL: p.ImageRef, + }) + case lipapi.PartFileRef: + b64, fname, err := fileDataFromPart(p) + if err != nil { + return nil, err + } + out = append(out, inputFilePart{ + Type: "input_file", + FileData: b64, + Filename: fname, + }) + default: + return nil, fmt.Errorf("%s: unsupported part kind %q", ID, p.Kind) + } + } + return out, nil +} + +func fileDataFromPart(p lipapi.Part) (dataB64, filename string, err error) { + filename = strings.TrimSpace(p.FileName) + ref := p.FileRef + if strings.HasPrefix(ref, "data:") { + _, b64, ok := stripDataURLBase64(ref) + if !ok { + return "", "", fmt.Errorf("%s: invalid data URL in file part", ID) + } + return b64, filename, nil + } + return "", "", fmt.Errorf("%s: file part requires a data URL", ID) +} + +func stripDataURLBase64(dataURL string) (mime, b64 string, ok bool) { + rest, ok := strings.CutPrefix(dataURL, "data:") + if !ok { + return "", "", false + } + mime, enc, found := strings.Cut(rest, ";") + if !found { + return "", "", false + } + const prefix = "base64," + encBody, ok := strings.CutPrefix(enc, prefix) + if !ok { + return "", "", false + } + return mime, encBody, true +} diff --git a/internal/plugins/backends/openaicodex/payload_test.go b/internal/plugins/backends/openaicodex/payload_test.go index 28f82ebb..1f62225f 100644 --- a/internal/plugins/backends/openaicodex/payload_test.go +++ b/internal/plugins/backends/openaicodex/payload_test.go @@ -28,7 +28,7 @@ func assertJSONEqual(t *testing.T, got, want []byte) { func payloadInputJSON(t *testing.T, call lipapi.Call) json.RawMessage { t.Helper() payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }, backend.Config{}) if err != nil { t.Fatal(err) @@ -87,11 +87,39 @@ func TestPayloadInputWireShape_assistantFunctionCallHistory(t *testing.T) { }}, }, }, + Tools: []lipapi.ToolDef{{Name: "get_weather", Parameters: json.RawMessage(`{"type":"object"}`)}}, }) want := `[{"type":"message","role":"user","content":"hi"},{"type":"function_call","id":"fc_1","call_id":"call_1","name":"get_weather","arguments":"{\"city\":\"NYC\"}"}]` assertJSONEqual(t, got, []byte(want)) } +func TestPayloadInputWireShape_assistantChatCompletionsToolCall(t *testing.T) { + t.Parallel() + got := payloadInputJSON(t, lipapi.Call{ + Messages: []lipapi.Message{ + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("run it")}}, + { + Role: lipapi.RoleAssistant, + Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: []byte(`{"id":"call_abc","type":"function","function":{"name":"bash","arguments":"{\"command\":\"echo pong\"}"}}`), + }}, + }, + { + Role: lipapi.RoleTool, + Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_abc", + Content: []byte("pong\n"), + }}, + }, + }, + Tools: []lipapi.ToolDef{{Name: "bash", Parameters: json.RawMessage(`{"type":"object"}`)}}, + }) + want := `[{"type":"message","role":"user","content":"run it"},{"type":"function_call","call_id":"call_abc","name":"bash","arguments":"{\"command\":\"echo pong\"}"},{"type":"function_call_output","call_id":"call_abc","output":"{\"output\":\"pong\\n\"}"}]` + assertJSONEqual(t, got, []byte(want)) +} + func TestPayloadInputWireShape_rejectsNonFunctionAssistantJSON(t *testing.T) { t.Parallel() _, err := backend.PayloadForCall(&lipapi.Call{ @@ -103,7 +131,7 @@ func TestPayloadInputWireShape_rejectsNonFunctionAssistantJSON(t *testing.T) { }}, }}, }, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }, backend.Config{}) if err == nil { t.Fatal("expected error") @@ -127,11 +155,73 @@ func TestPayloadInputWireShape_toolResultMessage(t *testing.T) { }}, }, }, + Tools: []lipapi.ToolDef{{Name: "get_weather", Parameters: json.RawMessage(`{"type":"object"}`)}}, + }) + want := `[{"type":"message","role":"user","content":"call the tool"},{"type":"function_call_output","call_id":"call_1","output":"{\"output\":\"{\\\"ok\\\":true}\"}"}]` + assertJSONEqual(t, got, []byte(want)) +} + +func TestPayloadInputWireShape_noToolsFlattensToolProtocolHistory(t *testing.T) { + t.Parallel() + // OpenCode can resend historical assistant tool calls and tool results during + // no-tools continuation/compaction turns. That history is still useful context, + // but Codex must not receive it as function_call/function_call_output when the + // request exposes zero tools. The structured shape caused no-tools turns to + // hang and sometimes made raw tool-call syntax leak back as assistant text. + got := payloadInputJSON(t, lipapi.Call{ + Messages: []lipapi.Message{ + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("continue")}}, + { + Role: lipapi.RoleAssistant, + Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: []byte(`{"id":"call_abc","type":"function","function":{"name":"bash","arguments":"{\"command\":\"echo pong\"}"}}`), + }}, + }, + { + Role: lipapi.RoleTool, + Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_abc", + Content: []byte(`{"matches":100}`), + }}, + }, + }, }) - want := `[{"type":"message","role":"user","content":"call the tool"},{"type":"function_call_output","call_id":"call_1","output":"{\"ok\":true}"}]` + want := `[{"type":"message","role":"user","content":"continue"},{"type":"message","role":"assistant","content":"Prior assistant tool call (tools unavailable in this request). call_id=call_abc. name=bash. arguments={\"command\":\"echo pong\"}"},{"type":"message","role":"user","content":"Prior tool output (tools unavailable in this request). call_id=call_abc.\n{\"matches\":100}"}]` assertJSONEqual(t, got, []byte(want)) } +func TestPayloadForCall_noToolsOmitsToolChoice(t *testing.T) { + t.Parallel() + // `tool_choice:auto` is correct only alongside a non-empty `tools` array. A + // no-tools OpenCode turn should be an ordinary text continuation; sending an + // auto tool choice in that state tells Codex that a tool protocol may still be + // available and reintroduces the slow/invalid no-tools behavior this connector + // already hit in live sessions. + payload, err := backend.PayloadForCall(&lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("continue")}, + }}, + }, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + raw, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(raw), `"tool_choice"`) { + t.Fatalf("no-tools payload must omit tool_choice: %s", raw) + } + if strings.Contains(string(raw), `"parallel_tool_calls"`) { + t.Fatalf("no-tools payload must omit parallel_tool_calls: %s", raw) + } +} + func TestPayloadForCall_modelInstructionsReasoningTemperatureToolsMultimodal(t *testing.T) { t.Parallel() parallel := true @@ -152,14 +242,14 @@ func TestPayloadForCall_modelInstructionsReasoningTemperatureToolsMultimodal(t * Tools: []lipapi.ToolDef{{ Name: "get_weather", Description: "Get weather", - Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}}}`), + Parameters: json.RawMessage(`{"type":"object","properties":{"city":{"type":"string"}},"required":["city"],"additionalProperties":false}`), }}, Options: lipapi.GenerationOptions{ ReasoningEffort: "high", ParallelToolCalls: ¶llel, }, } - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} payload, err := backend.PayloadForCall(&call, cand, backend.Config{DefaultReasoningEffort: "medium"}) if err != nil { t.Fatal(err) @@ -170,11 +260,14 @@ func TestPayloadForCall_modelInstructionsReasoningTemperatureToolsMultimodal(t * } s := string(raw) for _, want := range []string{ - `"model":"gpt-5.3-codex"`, + `"model":"gpt-5.3-codex-spark"`, `"store":false`, `"instructions":"custom codex instructions"`, `"reasoning"`, `"effort":"high"`, + `"summary":"auto"`, + `"include":["reasoning.encrypted_content"]`, + `"tool_choice":"auto"`, `"parallel_tool_calls":true`, `input_image`, `input_file`, @@ -218,6 +311,12 @@ func TestPayloadForCall_modelInstructionsReasoningTemperatureToolsMultimodal(t * } } +func compatIgnoreGenParamsExt() map[string]json.RawMessage { + return map[string]json.RawMessage{ + backend.ExtIgnoreUnsupportedGenParams: json.RawMessage(`true`), + } +} + func TestPayloadForCall_rejectsMaxOutputTokens(t *testing.T) { t.Parallel() maxTok := 512 @@ -231,12 +330,15 @@ func TestPayloadForCall_rejectsMaxOutputTokens(t *testing.T) { _, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ Primary: routing.Primary{Model: "gpt-5.4-mini"}, }, backend.Config{}) - if err == nil || !strings.Contains(err.Error(), "max output tokens") { + if err == nil { + t.Fatal("expected error for max_output_tokens without compat extension") + } + if !strings.Contains(err.Error(), "max_output_tokens") { t.Fatalf("err = %v", err) } } -func TestPayloadForCall_ignoresAnthropicMandatoryMaxTokens(t *testing.T) { +func TestPayloadForCall_compatDropsMaxOutputTokens(t *testing.T) { t.Parallel() maxTok := 512 call := lipapi.Call{ @@ -244,10 +346,39 @@ func TestPayloadForCall_ignoresAnthropicMandatoryMaxTokens(t *testing.T) { Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}, }}, - Options: lipapi.GenerationOptions{MaxOutputTokens: &maxTok}, - Extensions: map[string]json.RawMessage{ - "anthropic.model": json.RawMessage(`"claude-3-5-haiku-20241022"`), - }, + Options: lipapi.GenerationOptions{MaxOutputTokens: &maxTok}, + Extensions: compatIgnoreGenParamsExt(), + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatalf("err = %v", err) + } + raw, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + s := string(raw) + for _, key := range []string{`"max_output_tokens"`, `"max_tokens"`} { + if strings.Contains(s, key) { + t.Fatalf("payload must not emit %s: %s", key, s) + } + } +} + +func TestPayloadForCall_ignoresAnthropicMandatoryMaxTokensWithCompatExt(t *testing.T) { + t.Parallel() + maxTok := 512 + ext := compatIgnoreGenParamsExt() + ext["anthropic.model"] = json.RawMessage(`"claude-3-5-haiku-20241022"`) + call := lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hi")}, + }}, + Options: lipapi.GenerationOptions{MaxOutputTokens: &maxTok}, + Extensions: ext, } payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ Primary: routing.Primary{Model: "gpt-5.4-mini"}, @@ -303,9 +434,38 @@ func TestPayloadForCall_rejectsTemperature(t *testing.T) { _, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ Primary: routing.Primary{Model: "gpt-5.4-mini"}, }, backend.Config{}) - if err == nil || !strings.Contains(err.Error(), "temperature") { + if err == nil { + t.Fatal("expected error for temperature without compat extension") + } + if !strings.Contains(err.Error(), "temperature") { + t.Fatalf("err = %v", err) + } +} + +func TestPayloadForCall_compatDropsTemperature(t *testing.T) { + t.Parallel() + temp := 0.2 + call := lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hi")}, + }}, + Options: lipapi.GenerationOptions{Temperature: &temp}, + Extensions: compatIgnoreGenParamsExt(), + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { t.Fatalf("err = %v", err) } + raw, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(raw), `"temperature"`) { + t.Fatalf("payload must not emit temperature: %s", raw) + } } func TestPayloadForCall_rejectsTopP(t *testing.T) { @@ -321,11 +481,152 @@ func TestPayloadForCall_rejectsTopP(t *testing.T) { _, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ Primary: routing.Primary{Model: "gpt-5.4-mini"}, }, backend.Config{}) - if err == nil || !strings.Contains(err.Error(), "top_p") { + if err == nil { + t.Fatal("expected error for top_p without compat extension") + } + if !strings.Contains(err.Error(), "top_p") { t.Fatalf("err = %v", err) } } +func TestPayloadForCall_compatDropsTopP(t *testing.T) { + t.Parallel() + topP := 0.9 + call := lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hi")}, + }}, + Options: lipapi.GenerationOptions{TopP: &topP}, + Extensions: compatIgnoreGenParamsExt(), + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatalf("err = %v", err) + } + raw, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(raw), `"top_p"`) { + t.Fatalf("payload must not emit top_p: %s", raw) + } +} + +func TestPayloadForCall_stripsOpenAIProviderPrefix(t *testing.T) { + t.Parallel() + call := lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hi")}, + }}, + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "openai/gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + if payload.Model != "gpt-5.4-mini" { + t.Fatalf("model = %q, want %q (openai/ provider prefix must be stripped)", payload.Model, "gpt-5.4-mini") + } +} + +func TestPayloadForCall_foldsSystemMessagesIntoInstructions(t *testing.T) { + t.Parallel() + call := lipapi.Call{ + Instructions: []lipapi.Message{{Role: lipapi.RoleSystem, Parts: []lipapi.Part{lipapi.TextPart("Base instruction.")}}}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleSystem, Parts: []lipapi.Part{lipapi.TextPart("You are OpenCode. Use tools.")}}, + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}, + }, + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(payload.Instructions, "You are OpenCode. Use tools.") { + t.Fatalf("instructions must fold system message: %q", payload.Instructions) + } + if !strings.Contains(payload.Instructions, "Base instruction.") { + t.Fatalf("instructions must retain explicit instructions: %q", payload.Instructions) + } + raw, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(raw), `"role":"system"`) { + t.Fatalf("payload input must not contain a system role item: %s", raw) + } +} + +func TestPayloadForCall_systemMessageReplacesDefaultInstruction(t *testing.T) { + t.Parallel() + call := lipapi.Call{ + Messages: []lipapi.Message{ + {Role: lipapi.RoleSystem, Parts: []lipapi.Part{lipapi.TextPart("You are my custom agent.")}}, + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}, + }, + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + if payload.Instructions != "You are my custom agent." { + t.Fatalf("instructions = %q, want %q (system message should replace default)", payload.Instructions, "You are my custom agent.") + } +} + +func TestPayloadForCall_doesNotDuplicateSystemMessageAlreadyInInstructions(t *testing.T) { + t.Parallel() + bridge := "OpenCode compatibility mode:\n- bridge block" + call := lipapi.Call{ + Instructions: []lipapi.Message{{Role: lipapi.RoleSystem, Parts: []lipapi.Part{lipapi.TextPart("Base.\n\n" + bridge)}}}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleSystem, Parts: []lipapi.Part{lipapi.TextPart(bridge)}}, + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}, + }, + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + if got := strings.Count(payload.Instructions, "OpenCode compatibility mode"); got != 1 { + t.Fatalf("instructions must not duplicate bridge (count=%d): %q", got, payload.Instructions) + } +} + +func TestPayloadForCall_mergesSystemMessageThatIsSubstringOfInstructions(t *testing.T) { + t.Parallel() + call := lipapi.Call{ + Instructions: []lipapi.Message{{Role: lipapi.RoleSystem, Parts: []lipapi.Part{lipapi.TextPart("Be concise and helpful.")}}}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleSystem, Parts: []lipapi.Part{lipapi.TextPart("Be concise.")}}, + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}, + }, + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(payload.Instructions, "Be concise and helpful.") { + t.Fatalf("instructions must retain explicit block: %q", payload.Instructions) + } + if !strings.Contains(payload.Instructions, "Be concise.") { + t.Fatalf("instructions must merge system message even when substring of existing block: %q", payload.Instructions) + } +} + func TestPayloadForCall_defaultInstructionWhenEmpty(t *testing.T) { t.Parallel() call := lipapi.Call{ @@ -334,7 +635,7 @@ func TestPayloadForCall_defaultInstructionWhenEmpty(t *testing.T) { Parts: []lipapi.Part{lipapi.TextPart("hi")}, }}, } - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} payload, err := backend.PayloadForCall(&call, cand, backend.Config{}) if err != nil { t.Fatal(err) @@ -356,7 +657,7 @@ func TestPayloadForCall_configDefaultsWhenCallUnset(t *testing.T) { Parts: []lipapi.Part{lipapi.TextPart("hi")}, }}, } - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} payload, err := backend.PayloadForCall(&call, cand, backend.Config{ DefaultReasoningEffort: "low", }) @@ -384,7 +685,7 @@ func TestPayloadForCall_doesNotMutateForClientMarkers(t *testing.T) { Tools: []lipapi.ToolDef{{Name: "bash"}}, } payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }, backend.Config{}) if err != nil { t.Fatal(err) @@ -394,7 +695,7 @@ func TestPayloadForCall_doesNotMutateForClientMarkers(t *testing.T) { } } -func TestPayloadForCall_nonHermesToolsRemainStrictAndParallelUnset(t *testing.T) { +func TestPayloadForCall_nonHermesToolsRemainStrictAndParallelFalse(t *testing.T) { t.Parallel() call := lipapi.Call{ Messages: []lipapi.Message{{ @@ -404,7 +705,7 @@ func TestPayloadForCall_nonHermesToolsRemainStrictAndParallelUnset(t *testing.T) Tools: []lipapi.ToolDef{{Name: "bash"}}, } payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }, backend.Config{}) if err != nil { t.Fatal(err) @@ -413,8 +714,8 @@ func TestPayloadForCall_nonHermesToolsRemainStrictAndParallelUnset(t *testing.T) if !strings.Contains(string(raw), `"strict":true`) { t.Fatalf("expected strict=true for non-Hermes: %s", raw) } - if payload.ParallelToolCalls != nil { - t.Fatalf("expected parallel_tool_calls unset for non-Hermes: %+v", payload.ParallelToolCalls) + if payload.ParallelToolCalls == nil || *payload.ParallelToolCalls { + t.Fatalf("expected parallel_tool_calls=false for non-Hermes: %+v", payload.ParallelToolCalls) } } @@ -431,7 +732,7 @@ func TestPayloadForCall_hermesToolStrictFalseAndParallelTrue(t *testing.T) { }, } payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }, backend.Config{}) if err != nil { t.Fatal(err) @@ -460,7 +761,7 @@ func TestPayloadForCall_hermesRespectsExplicitParallelFalse(t *testing.T) { }, } payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }, backend.Config{}) if err != nil { t.Fatal(err) @@ -469,3 +770,208 @@ func TestPayloadForCall_hermesRespectsExplicitParallelFalse(t *testing.T) { t.Fatalf("expected explicit parallel_tool_calls=false honored for Hermes: %+v", payload.ParallelToolCalls) } } + +func TestPayloadForCall_mixedAssistantTextAndToolCall(t *testing.T) { + t.Parallel() + call := lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleAssistant, + Parts: []lipapi.Part{ + lipapi.TextPart("ok"), + { + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"call_abc","type":"function","function":{"name":"legacy_fn","arguments":"{}"}}`), + }, + }, + }}, + Tools: []lipapi.ToolDef{{Name: "legacy_fn", Parameters: json.RawMessage(`{"type":"object"}`)}}, + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + raw, _ := json.Marshal(payload) + if !strings.Contains(string(raw), `"content":"ok"`) { + t.Fatalf("missing assistant text item: %s", raw) + } + if !strings.Contains(string(raw), `"type":"function_call"`) || !strings.Contains(string(raw), `"call_id":"call_abc"`) { + t.Fatalf("missing function call item: %s", raw) + } +} + +func TestPayloadForCall_normalizesMissingAdditionalPropertiesForStrict(t *testing.T) { + t.Parallel() + strict, raw := codexToolStrict(t, `{"type":"object","properties":{"patch":{"type":"string"}},"required":["patch"]}`) + if !strict { + t.Fatalf("schema missing only additionalProperties:false should be normalized to strict=true: %s", raw) + } +} + +func TestPayloadForCall_strictCompatibleToolSchemaUsesStrictTrue(t *testing.T) { + t.Parallel() + strict, raw := codexToolStrict(t, `{"type":"object","properties":{"patch":{"type":"string"}},"required":["patch"],"additionalProperties":false}`) + if !strict { + t.Fatalf("strict-compatible schema must keep strict=true: %s", raw) + } +} + +func TestPayloadForCall_parameterlessObjectSchemaGetsAdditionalPropertiesAndRequired(t *testing.T) { + t.Parallel() + strict, raw := codexToolStrict(t, `{"type":"object"}`) + if !strict { + t.Fatalf("parameterless object schema must stay strict=true after normalization: %s", raw) + } + var decoded struct { + Tools []struct { + Parameters map[string]any `json:"parameters"` + } `json:"tools"` + } + if err := json.Unmarshal(raw, &decoded); err != nil { + t.Fatal(err) + } + if len(decoded.Tools) != 1 { + t.Fatalf("tools: %v", decoded.Tools) + } + params := decoded.Tools[0].Parameters + if ap, ok := params["additionalProperties"].(bool); !ok || ap { + t.Fatalf("parameterless object must have additionalProperties:false: %#v", params) + } + req, ok := params["required"].([]any) + if !ok || len(req) != 0 { + t.Fatalf("parameterless object must have required:[]: %#v", params) + } +} + +func TestPayloadForCall_parameterlessObjectWithAdditionalPropertiesTrueIsStrictFalse(t *testing.T) { + t.Parallel() + // A parameterless object that explicitly allows additional properties is not + // strict-compatible; it must be sent strict:false so the upstream does not + // reject it. + strict, _ := codexToolStrict(t, `{"type":"object","additionalProperties":true}`) + if strict { + t.Fatal("parameterless object with additionalProperties:true must use strict=false") + } +} + +func TestPayloadForCall_noToolsOmitsExplicitParallelToolCalls(t *testing.T) { + t.Parallel() + // A no-tools turn must not leak any tool-protocol field, even when the client + // explicitly requests parallel_tool_calls. Keep the absence of tools explicit. + parallel := true + call := lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("continue")}, + }}, + Options: lipapi.GenerationOptions{ParallelToolCalls: ¶llel}, + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + raw, err := json.Marshal(payload) + if err != nil { + t.Fatal(err) + } + if strings.Contains(string(raw), `"parallel_tool_calls"`) { + t.Fatalf("no-tools payload must omit parallel_tool_calls even when explicit: %s", raw) + } +} + +func TestPayloadForCall_schemaNormalizationSkipsNonSchemaKeywords(t *testing.T) { + t.Parallel() + // The "default" keyword holds an example value, not a subschema: it must not + // be mutated with additionalProperties:false/required:[] even though it looks + // like an object schema. + schema := `{"type":"object","properties":{"mode":{"type":"string","default":{"type":"object","properties":{"x":{"type":"string"}}}}},"required":["mode"],"additionalProperties":false}` + _, raw := codexToolStrict(t, schema) + var decoded struct { + Tools []struct { + Parameters map[string]any `json:"parameters"` + } `json:"tools"` + } + if err := json.Unmarshal(raw, &decoded); err != nil || len(decoded.Tools) != 1 { + t.Fatalf("decode tools: %v raw=%s", err, raw) + } + params := decoded.Tools[0].Parameters + props, _ := params["properties"].(map[string]any) + mode, _ := props["mode"].(map[string]any) + def, _ := mode["default"].(map[string]any) + if def == nil { + t.Fatalf("default value missing: %#v", params) + } + if _, ok := def["additionalProperties"]; ok { + t.Fatalf("default value must not be mutated with additionalProperties: %#v", def) + } + if _, ok := def["required"]; ok { + t.Fatalf("default value must not be mutated with required: %#v", def) + } +} + +func TestPayloadForCall_composedLooseToolSchemaUsesStrictFalse(t *testing.T) { + t.Parallel() + cases := []struct { + name string + schema string + }{ + { + name: "anyOf object missing required", + schema: `{"type":"object","properties":{"mode":{"anyOf":[{"type":"object","properties":{"path":{"type":"string"}},"additionalProperties":false}]}},"required":["mode"],"additionalProperties":false}`, + }, + { + name: "$ref is conservative false", + schema: `{"type":"object","properties":{"path":{"$ref":"#/$defs/path"}},"required":["path"],"additionalProperties":false,"$defs":{"path":{"type":"string"}}}`, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + strict, raw := codexToolStrict(t, tc.schema) + if strict { + t.Fatalf("composed/ambiguous schema must use strict=false: %s", raw) + } + }) + } +} + +func TestPayloadForCall_composedStrictToolSchemaUsesStrictTrue(t *testing.T) { + t.Parallel() + strict, raw := codexToolStrict(t, `{"type":"object","properties":{"mode":{"oneOf":[{"type":"object","properties":{"path":{"type":"string"}},"required":["path"],"additionalProperties":false}]}},"required":["mode"],"additionalProperties":false}`) + if !strict { + t.Fatalf("strict-compatible composed schema must keep strict=true: %s", raw) + } +} + +func codexToolStrict(t *testing.T, schema string) (bool, []byte) { + t.Helper() + call := lipapi.Call{ + Messages: []lipapi.Message{{Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("hi")}}}, + Tools: []lipapi.ToolDef{{ + Name: "apply_patch", + Parameters: json.RawMessage(schema), + }}, + } + payload, err := backend.PayloadForCall(&call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }, backend.Config{}) + if err != nil { + t.Fatal(err) + } + raw, _ := json.Marshal(payload) + var decoded struct { + Tools []struct { + Strict bool `json:"strict"` + } `json:"tools"` + } + if err := json.Unmarshal(raw, &decoded); err != nil { + t.Fatal(err) + } + if len(decoded.Tools) != 1 { + t.Fatalf("tools: %v", decoded.Tools) + } + return decoded.Tools[0].Strict, raw +} diff --git a/internal/plugins/backends/openaicodex/payload_tools.go b/internal/plugins/backends/openaicodex/payload_tools.go new file mode 100644 index 00000000..e60d594f --- /dev/null +++ b/internal/plugins/backends/openaicodex/payload_tools.go @@ -0,0 +1,74 @@ +package openaicodex + +import ( + "encoding/json" + "fmt" + "log/slog" + "os" + "slices" + "sync" + + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +type toolPayload struct { + Type string `json:"type"` + Name string `json:"name"` + Description string `json:"description,omitempty"` + Parameters map[string]any `json:"parameters"` + Strict bool `json:"strict"` +} + +func buildTools(tools []lipapi.ToolDef, toolStrictDisabled bool) ([]toolPayload, error) { + out := make([]toolPayload, 0, len(tools)) + for _, t := range tools { + var schema map[string]any + if len(t.Parameters) > 0 { + if err := json.Unmarshal(t.Parameters, &schema); err != nil { + return nil, fmt.Errorf("%s: tool %q parameters: %w", ID, t.Name, err) + } + } + if schema == nil { + schema = map[string]any{} + } + schema, strict := normalizeToolSchemaForCodex(schema) + strict = strict && !toolStrictDisabled + if codexToolDebugEnabled() { + slog.Debug("openaicodex.tool_schema", "tool", t.Name, "strict", strict, "keys", sortedSchemaKeys(schema)) + } + out = append(out, toolPayload{ + Type: "function", + Name: t.Name, + Description: t.Description, + Parameters: schema, + Strict: strict, + }) + } + return out, nil +} + +func sortedSchemaKeys(schema map[string]any) []string { + keys := make([]string, 0, len(schema)) + for k := range schema { + keys = append(keys, k) + } + slices.Sort(keys) + return keys +} + +var ( + codexToolDebugEnabledValue = sync.OnceValue(func() bool { + return os.Getenv("LIP_CODEX_DEBUG_TOOLS") == "1" + }) + codexToolDeltaDebugEnabledValue = sync.OnceValue(func() bool { + return os.Getenv("LIP_CODEX_DEBUG_TOOL_DELTAS") == "1" + }) +) + +func codexToolDebugEnabled() bool { + return codexToolDebugEnabledValue() +} + +func codexToolDeltaDebugEnabled() bool { + return codexToolDeltaDebugEnabledValue() +} diff --git a/internal/plugins/backends/openaicodex/perf_benchmark_test.go b/internal/plugins/backends/openaicodex/perf_benchmark_test.go new file mode 100644 index 00000000..f1058272 --- /dev/null +++ b/internal/plugins/backends/openaicodex/perf_benchmark_test.go @@ -0,0 +1,100 @@ +package openaicodex + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func BenchmarkBuildToolsStrictSchema(b *testing.B) { + params := json.RawMessage(`{ + "type":"object", + "properties":{ + "command":{"type":"string"}, + "workdir":{"type":"string"}, + "timeout":{"type":"integer"}, + "env":{"type":"object","properties":{"PATH":{"type":"string"},"HOME":{"type":"string"}},"required":["PATH","HOME"]}, + "args":{"type":"array","items":{"type":"object","properties":{"name":{"type":"string"},"value":{"type":"string"}},"required":["name","value"]}} + }, + "required":["command","workdir","timeout","env","args"] + }`) + tools := []lipapi.ToolDef{{Name: "bash", Description: "run command", Parameters: params}} + b.ReportAllocs() + for b.Loop() { + if _, err := buildTools(tools, false); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkToolResultStringLargeOutput(b *testing.B) { + raw, err := json.Marshal(map[string]any{ + "output": strings.Repeat("line output\n", 4096), + "exit_code": 0, + "workdir": `C:\Users\Mateusz\source\repos\go-llm-interactive-proxy`, + "metadata": map[string]any{ + "ignored": strings.Repeat("nested data", 256), + }, + }) + if err != nil { + b.Fatal(err) + } + part := lipapi.Part{Kind: lipapi.PartToolResult, ToolCallID: "call_1", Content: raw} + b.ReportAllocs() + for b.Loop() { + if got := toolResultString(part); got == "" { + b.Fatal("empty result") + } + } +} + +func BenchmarkCodexEventMapperResponseCompleted(b *testing.B) { + data := `{"type":"response.completed","response":{"id":"resp_1","output":[{"type":"message","content":[{"type":"output_text","text":"done"}]},{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}],"usage":{"input_tokens":100,"output_tokens":20,"total_tokens":120}}}` + b.ReportAllocs() + for b.Loop() { + mapper := newCodexEventMapper(0) + if err := mapper.handleData(data); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkWSContinuationPrepareRecord(b *testing.B) { + cfg := &Config{AccountID: "acct-1"} + call := lipapi.Call{ + Session: lipapi.SessionRef{ClientSessionID: "session-1"}, + Extensions: map[string]json.RawMessage{"agent": json.RawMessage(`"opencode"`)}, + } + base := benchmarkContinuationPayload(50) + next := base + next.Input = append(append([]inputItem(nil), base.Input...), textMessageItem{Type: "message", Role: "user", Content: "continue"}) + baseFP := fingerprintInputItems(base.Input) + nextFP := fingerprintInputItems(next.Input) + b.ReportAllocs() + for b.Loop() { + store := newWSContinuationStore(time.Minute, 8) + store.recordWithFingerprints(cfg, call, base, baseFP, "resp_1") + candidate := next + if !store.prepareWithFingerprints(context.Background(), cfg, call, &candidate, nextFP) { + b.Fatal("expected continuation") + } + } +} + +func benchmarkContinuationPayload(items int) Payload { + input := make([]inputItem, 0, items) + for i := range items { + input = append(input, textMessageItem{Type: "message", Role: "user", Content: strings.Repeat("context ", 16) + string(rune('a'+i%26))}) + } + return Payload{ + Model: "gpt-5.4-mini", + Instructions: "instructions", + PromptCacheKey: "session-1", + Input: input, + Tools: []toolPayload{{Type: "function", Name: "bash"}}, + } +} diff --git a/internal/plugins/backends/openaicodex/plugin.go b/internal/plugins/backends/openaicodex/plugin.go index 99a2288c..c2a1e20a 100644 --- a/internal/plugins/backends/openaicodex/plugin.go +++ b/internal/plugins/backends/openaicodex/plugin.go @@ -3,10 +3,13 @@ package openaicodex import ( "bytes" "context" + "errors" "fmt" + "log/slog" "net/http" "strings" "sync" + "time" "github.com/matdev83/go-llm-interactive-proxy/internal/core/execbackend" "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" @@ -16,6 +19,14 @@ import ( "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk/modelinventory" ) +// errManagedAccountsExhausted signals that the managed WebSocket open failed because +// every managed account was unusable due to account-level auth/rate-limit rejection, +// not a WebSocket transport problem. openWithFallback uses it to skip the global WS +// cooldown: the bad accounts are already marked in the credpool and excluded from +// future selection, so disabling WS for the whole backend would only delay recovery +// of accounts whose per-account cooldown expires sooner than the WS fallback window. +var errManagedAccountsExhausted = errors.New("managed oauth accounts exhausted") + var backendCaps = lipapi.NewBackendCaps( lipapi.CapabilityStreaming, lipapi.CapabilityTools, @@ -26,11 +37,14 @@ var backendCaps = lipapi.NewBackendCaps( ) type backendRuntime struct { - mu sync.Mutex - cfg Config - oauth *accountStore - downgrade downgradePolicy - usageEst *usageEstimator + mu sync.Mutex + cfg Config + oauth *accountStore + downgrade downgradePolicy + usageEst *usageEstimator + cooldown *transportCooldown + wsSessions *wsSessionStore + continuation *wsContinuationStore } func New(cfg Config) execbackend.Backend { @@ -42,6 +56,14 @@ func New(cfg Config) execbackend.Backend { if err != nil { return newConfigErrorBackend(err) } + transport, err := NormalizeTransport(resolved.Transport, resolved.ExperimentalWebSocket) + if err != nil { + return newConfigErrorBackend(err) + } + resolved.Transport = transport + if resolved.WebSocketFallbackCooldown <= 0 { + resolved.WebSocketFallbackCooldown = DefaultWebSocketFallbackCooldown + } usageEst, err := newUsageEstimator() if err != nil { return newConfigErrorBackend(err) @@ -50,6 +72,9 @@ func New(cfg Config) execbackend.Backend { rt.cfg = resolved rt.oauth = store rt.usageEst = usageEst + rt.cooldown = newTransportCooldown(resolved.WebSocketFallbackCooldown) + rt.wsSessions = newWSSessionStore() + rt.continuation = newWSContinuationStore(codexContinuationTTL, codexContinuationMaxEntries) if store == nil { if err := checkcfg.RequireNonEmpty(ID, "access_token", resolved.AccessToken); err != nil { return newConfigErrorBackend(err) @@ -102,56 +127,147 @@ func (rt *backendRuntime) open(ctx context.Context, call lipapi.Call, cand routi cfg := rt.cfg store := rt.oauth usageEst := rt.usageEst + cooldown := rt.cooldown + wsSessions := rt.wsSessions + continuation := rt.continuation + downgrade := rt.downgrade rt.mu.Unlock() if store != nil { - return openManaged(ctx, &cfg, store, call, cand, rt.downgrade, usageEst) + return openWithFallback(ctx, &cfg, cooldown, + func() (lipapi.ManagedEventStream, error) { + return openManaged(ctx, &cfg, store, call, cand, downgrade, usageEst) + }, + func() (lipapi.ManagedEventStream, error) { + return openManagedWS(ctx, &cfg, store, call, cand, downgrade, usageEst, wsSessions, continuation) + }, + ) } - return open(ctx, &cfg, rt, call, cand) + return openWithFallback(ctx, &cfg, cooldown, + func() (lipapi.ManagedEventStream, error) { return openHTTP(ctx, &cfg, rt, downgrade, call, cand) }, + func() (lipapi.ManagedEventStream, error) { + return openWS(ctx, &cfg, downgrade, usageEst, wsSessions, continuation, call, cand) + }, + ) } -func openManaged(ctx context.Context, cfg *Config, store *accountStore, call lipapi.Call, cand routing.AttemptCandidate, policy downgradePolicy, usageEst *usageEstimator) (lipapi.ManagedEventStream, error) { +// openWithFallback orchestrates transport selection for both the static-token +// and managed paths. HTTPS is used directly when configured or when the WS +// cooldown is active; WebSocket is used strictly when configured; auto mode +// tries WS and falls back to HTTPS only on a WS fallback-eligible error, +// recording the cooldown. The openHTTPS/openWS closures carry the path-specific +// account wiring so this helper stays free of managed/static differences. +func openWithFallback( + ctx context.Context, + cfg *Config, + cooldown *transportCooldown, + openHTTPS, openWS func() (lipapi.ManagedEventStream, error), +) (lipapi.ManagedEventStream, error) { + switch cfg.Transport { + case TransportHTTPS: + return openHTTPS() + case TransportWebSocket: + return openWS() + default: + if cooldown.active() { + return openHTTPS() + } + es, err := openWS() + if err == nil { + return es, nil + } + // Account-level exhaustion from the managed WS path is not a WebSocket + // transport problem: the bad accounts are already marked and excluded, and + // HTTPS fallback may still succeed with a usable account. Skip the global WS + // cooldown so a later-recovered account can use WS again without waiting out + // the fallback window. + if errors.Is(err, errManagedAccountsExhausted) { + return openHTTPS() + } + if isWSFallbackError(ctx, err) { + cooldown.markFailed() + return openHTTPS() + } + return nil, err + } +} + +// selectManagedSession prepares the per-account session state shared by the WS +// and HTTP managed paths: picks an account for the conversation, derives the +// per-account call config, and resolves the plan-scoped model. The returned +// callCfg is a caller-owned copy so per-call mutation (e.g. OAuth refresh on +// the static path) never leaks back into the stored account config. +func selectManagedSession(env *codexOpenEnv, cfg *Config, store *accountStore, policy downgradePolicy) (managedAccount, Config, string, error) { + acct, err := store.selectAccountForSession(env.convID) + if err != nil { + return managedAccount{}, Config{}, "", err + } + callCfg := callCfgFromAccount(cfg, acct) + planType := firstNonEmpty(acct.PlanType, cfg.PlanTypeHint) + return acct, callCfg, policy.modelForPlan(env.originalModel, planType), nil +} + +type managedOpenAttemptFn func(ctx context.Context, env *codexOpenEnv, callCfg *Config, model string, usageEst *usageEstimator) (lipapi.ManagedEventStream, *http.Response, error) + +func openManagedAccountLoop(ctx context.Context, cfg *Config, store *accountStore, call lipapi.Call, cand routing.AttemptCandidate, policy downgradePolicy, usageEst *usageEstimator, attempt managedOpenAttemptFn) (lipapi.ManagedEventStream, error) { env, err := prepareCodexOpenEnv(ctx, cfg, call, cand, policy) if err != nil { return nil, err } retries := maxManagedRetries(store) for range retries { - acct, err := store.selectAccountForSession(env.convID) + acct, callCfg, model, err := selectManagedSession(env, cfg, store, policy) if err != nil { - return nil, err + return nil, fmt.Errorf("%s: no usable managed oauth accounts: %w", ID, errManagedAccountsExhausted) } - planType := firstNonEmpty(acct.PlanType, cfg.PlanTypeHint) - body, err := env.marshalWithModel(policy.modelForPlan(env.originalModel, planType)) + es, resp, err := attempt(ctx, env, &callCfg, model, usageEst) + if err == nil { + if resp != nil { + if qh := codexQuotaHeaders(resp.Header); len(qh) > 0 { + _ = store.persistQuotaHeaders(acct, qh) + } + } + return es, nil + } + if resp != nil { + switch resp.StatusCode { + case http.StatusUnauthorized, http.StatusForbidden: + store.markAuthInvalid(acct) + continue + case http.StatusTooManyRequests: + now := store.now() + store.markRateLimited(acct, credpool.CooldownFromRetryAfterOrFallback(resp.Header.Get("Retry-After"), now, store.fallback)) + continue + } + } + return nil, err + } + return nil, fmt.Errorf("%s: no usable managed oauth accounts: %w", ID, errManagedAccountsExhausted) +} + +func openManagedWS(ctx context.Context, cfg *Config, store *accountStore, call lipapi.Call, cand routing.AttemptCandidate, policy downgradePolicy, usageEst *usageEstimator, wsSessions *wsSessionStore, continuation *wsContinuationStore) (lipapi.ManagedEventStream, error) { + return openManagedAccountLoop(ctx, cfg, store, call, cand, policy, usageEst, func(ctx context.Context, env *codexOpenEnv, callCfg *Config, model string, usageEst *usageEstimator) (lipapi.ManagedEventStream, *http.Response, error) { + return openWSPrepared(ctx, env, callCfg, model, call, usageEst, wsSessions, continuation) + }) +} + +func openManaged(ctx context.Context, cfg *Config, store *accountStore, call lipapi.Call, cand routing.AttemptCandidate, policy downgradePolicy, usageEst *usageEstimator) (lipapi.ManagedEventStream, error) { + return openManagedAccountLoop(ctx, cfg, store, call, cand, policy, usageEst, func(ctx context.Context, env *codexOpenEnv, callCfg *Config, model string, usageEst *usageEstimator) (lipapi.ManagedEventStream, *http.Response, error) { + body, err := env.marshalWithModel(model) if err != nil { - return nil, err + return nil, nil, err } - callCfg := callCfgFromAccount(cfg, acct) attempt := env.newAttempt(ctx, cfg, call, body, usageEst) - resp, err := attempt.doRequest(&callCfg) + resp, err := attempt.doRequest(callCfg) if err != nil { - return nil, err + return nil, nil, err } switch resp.StatusCode { - case http.StatusUnauthorized, http.StatusForbidden: - readLimitedClose(resp) - store.markAuthInvalid(acct) - continue - case http.StatusTooManyRequests: - readLimitedClose(resp) - now := store.now() - store.markRateLimited(acct, credpool.CooldownFromRetryAfterOrFallback(resp.Header.Get("Retry-After"), now, store.fallback)) - continue - } - es, finalResp, err := completeCodexOpenAttempt(attempt, resp, &callCfg) - if err != nil { - return nil, err - } - if qh := codexQuotaHeaders(finalResp.Header); len(qh) > 0 { - _ = store.persistQuotaHeaders(acct, qh) + case http.StatusUnauthorized, http.StatusForbidden, http.StatusTooManyRequests: + b := readLimitedClose(resp) + return nil, resp, upstreamHTTPError(resp.StatusCode, b) } - return es, nil - } - return nil, fmt.Errorf("%s: no usable managed oauth accounts", ID) + return completeCodexOpenAttempt(attempt, resp, callCfg) + }) } func maxManagedRetries(store *accountStore) int { @@ -161,12 +277,12 @@ func maxManagedRetries(store *accountStore) int { return len(store.meta) } -func open(ctx context.Context, cfg *Config, rt *backendRuntime, call lipapi.Call, cand routing.AttemptCandidate) (lipapi.ManagedEventStream, error) { - env, err := prepareCodexOpenEnv(ctx, cfg, call, cand, rt.downgrade) +func openHTTP(ctx context.Context, cfg *Config, rt *backendRuntime, policy downgradePolicy, call lipapi.Call, cand routing.AttemptCandidate) (lipapi.ManagedEventStream, error) { + env, err := prepareCodexOpenEnv(ctx, cfg, call, cand, policy) if err != nil { return nil, err } - body, err := env.marshalWithModel(rt.downgrade.modelForPlan(env.originalModel, cfg.PlanTypeHint)) + body, err := env.marshalWithModel(policy.modelForPlan(env.originalModel, cfg.PlanTypeHint)) if err != nil { return nil, err } @@ -216,10 +332,33 @@ func doCodexRequest(ctx context.Context, client *http.Client, endpoint string, b return nil, fmt.Errorf("%s: build request: %w", ID, err) } applyCodexHeaders(req, *cfg, convID) + start := time.Now() + if debugTurnsEnabled() { + slog.DebugContext(ctx, "openaicodex.debug.http_request_start", + "endpoint", endpoint, + "body_bytes", len(body), + "conversation_id", convID, + ) + } resp, err := client.Do(req) if err != nil { + if debugTurnsEnabled() { + slog.DebugContext(ctx, "openaicodex.debug.http_request_done", + "endpoint", endpoint, + "duration_ms", time.Since(start).Milliseconds(), + "status", "error", + "error", err.Error(), + ) + } return nil, fmt.Errorf("%s: request: %w", ID, err) } + if debugTurnsEnabled() { + slog.DebugContext(ctx, "openaicodex.debug.http_request_done", + "endpoint", endpoint, + "duration_ms", time.Since(start).Milliseconds(), + "status", resp.StatusCode, + ) + } return resp, nil } @@ -271,18 +410,7 @@ var builtinCodexModelIDs = []string{ "gpt-5.5", "gpt-5.4", "gpt-5.4-mini", - "gpt-5.3-codex", - "gpt-5.2-codex", - "gpt-5.2", - "gpt-5.1-codex-max", - "gpt-5.1-codex", - "gpt-5.1-codex-mini", - "gpt-5.1", - "gpt-5-codex", - "gpt-5-codex-mini", - "gpt-5", - "gpt-oss-120b", - "gpt-oss-20b", + "gpt-5.3-codex-spark", } func newConfigErrorBackend(err error) execbackend.Backend { diff --git a/internal/plugins/backends/openaicodex/plugin_internal_test.go b/internal/plugins/backends/openaicodex/plugin_internal_test.go new file mode 100644 index 00000000..bdc9979d --- /dev/null +++ b/internal/plugins/backends/openaicodex/plugin_internal_test.go @@ -0,0 +1,108 @@ +package openaicodex + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "testing" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func TestOpenWithFallback_managedExhaustionSkipsGlobalCooldown(t *testing.T) { + t.Parallel() + now := time.Time{} + cooldown := &transportCooldown{cooldown: time.Hour, now: func() time.Time { return now }} + cfg := &Config{Transport: TransportAuto} + var httpsCalls int + es, err := openWithFallback(context.Background(), cfg, cooldown, + func() (lipapi.ManagedEventStream, error) { + httpsCalls++ + return nil, errors.New("https failed") + }, + func() (lipapi.ManagedEventStream, error) { + return nil, fmt.Errorf("ws accounts exhausted: %w", errManagedAccountsExhausted) + }, + ) + if err == nil || es != nil { + t.Fatalf("expected https fallback error, got es=%v err=%v", es, err) + } + if httpsCalls != 1 { + t.Fatalf("httpsCalls = %d, want 1", httpsCalls) + } + if cooldown.active() { + t.Fatal("managed account exhaustion must not activate global WS cooldown") + } +} + +func TestOpenManagedAccountLoop_httpExhaustionClassifiesManagedAccountsExhausted(t *testing.T) { + t.Parallel() + dir := t.TempDir() + for _, name := range []string{"a.json", "b.json"} { + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte(`{"account_id":"`+name+`","access_token":"tok-`+name+`"}`), 0o600); err != nil { + t.Fatal(err) + } + } + store, err := newAccountStore(Config{ + ManagedOAuthEnabled: true, + ManagedOAuthStoragePath: dir, + ManagedOAuthSelectionStrategy: "first-available", + RateLimitFallback: time.Hour, + }) + if err != nil { + t.Fatal(err) + } + cfg := Config{BaseURL: "http://127.0.0.1/backend-api/codex/responses"} + _, err = openManagedAccountLoop(context.Background(), &cfg, store, lipapi.Call{}, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }, newDowngradePolicy(cfg), nil, func(_ context.Context, _ *codexOpenEnv, _ *Config, _ string, _ *usageEstimator) (lipapi.ManagedEventStream, *http.Response, error) { + return nil, &http.Response{StatusCode: http.StatusTooManyRequests, Header: make(http.Header)}, fmt.Errorf("account rate limited") + }) + if err == nil { + t.Fatal("expected exhaustion error") + } + if !errors.Is(err, errManagedAccountsExhausted) { + t.Fatalf("err = %v, want errManagedAccountsExhausted", err) + } +} + +func TestOpenWithFallback_transportFailureActivatesCooldown(t *testing.T) { + t.Parallel() + now := time.Time{} + cooldown := &transportCooldown{cooldown: time.Hour, now: func() time.Time { return now }} + cfg := &Config{Transport: TransportAuto} + _, _ = openWithFallback(context.Background(), cfg, cooldown, + func() (lipapi.ManagedEventStream, error) { return nil, errors.New("https ok path") }, + func() (lipapi.ManagedEventStream, error) { + return nil, newWSTransportError(errors.New("dial failed")) + }, + ) + if !cooldown.active() { + t.Fatal("transport WS failure must activate global WS cooldown") + } +} + +func TestOpenWithFallback_nonTransportFailureSkipsCooldown(t *testing.T) { + t.Parallel() + now := time.Time{} + cooldown := &transportCooldown{cooldown: time.Hour, now: func() time.Time { return now }} + cfg := &Config{Transport: TransportAuto} + _, err := openWithFallback(context.Background(), cfg, cooldown, + func() (lipapi.ManagedEventStream, error) { return nil, errors.New("https must not run") }, + func() (lipapi.ManagedEventStream, error) { + return nil, fmt.Errorf("%s: marshal payload: %w", ID, errors.New("bad field")) + }, + ) + if err == nil { + t.Fatal("expected non-transport error to propagate") + } + if cooldown.active() { + t.Fatal("programmer/data errors must not activate global WS cooldown") + } +} diff --git a/internal/plugins/backends/openaicodex/plugin_test.go b/internal/plugins/backends/openaicodex/plugin_test.go index 5b4e46f3..1a6dadbf 100644 --- a/internal/plugins/backends/openaicodex/plugin_test.go +++ b/internal/plugins/backends/openaicodex/plugin_test.go @@ -2,13 +2,18 @@ package openaicodex_test import ( "context" + "encoding/json" "errors" "io" "net/http" "net/http/httptest" + "slices" "strings" + "sync" "testing" + "time" + gorillawebsocket "github.com/gorilla/websocket" "github.com/matdev83/go-llm-interactive-proxy/internal/core/b2bua" "github.com/matdev83/go-llm-interactive-proxy/internal/core/execbackend" "github.com/matdev83/go-llm-interactive-proxy/internal/core/hooks" @@ -35,7 +40,7 @@ func TestNew_configErrors(t *testing.T) { t.Parallel() be := backend.New(tc.cfg) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected config error") @@ -52,7 +57,7 @@ func TestNew_missingAccessTokenConfigError(t *testing.T) { //nolint:paralleltest withHomeDir(t, t.TempDir()) be := backend.New(backend.Config{BaseURL: "http://127.0.0.1"}) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected config error") @@ -85,7 +90,7 @@ func TestOpen_refbackendHeadersAndEvents(t *testing.T) { Parts: []lipapi.Part{lipapi.TextPart("hello")}, }}, } - cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex"}} + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} es, err := be.Open(context.Background(), call, cand) if err != nil { t.Fatal(err) @@ -128,7 +133,7 @@ func TestOpen_baseURLEndingResponses(t *testing.T) { HTTPClient: ts.Client(), }) es, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -143,7 +148,7 @@ func TestOpen_nilContext(t *testing.T) { t.Parallel() be := backend.New(backend.Config{BaseURL: "http://127.0.0.1", AccessToken: "tok"}) _, err := be.Open(nil, sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil || !errors.Is(err, lipapi.ErrNilContext) { t.Fatalf("err = %v", err) @@ -163,9 +168,10 @@ func TestOpen_non2xxIncludesStatus(t *testing.T) { BaseURL: ts.URL + "/backend-api/codex", AccessToken: "sk-codex", HTTPClient: ts.Client(), + Transport: backend.TransportHTTPS, }) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected error") @@ -196,7 +202,7 @@ func TestOpen_429ReturnsUpstreamErrorWithoutRefresh(t *testing.T) { HTTPClient: srv.Client(), }) _, err := be.Open(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err == nil { t.Fatal("expected rate limit error") @@ -260,11 +266,949 @@ func TestOpen_routeParamsReachCodexPayload(t *testing.T) { } } +// TestOpen_rejectsUnsupportedGenerationParamsWithoutCompatExt proves plain calls with +// generation options the Codex Responses API does not support fail explicitly. +func TestOpen_rejectsUnsupportedGenerationParamsWithoutCompatExt(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + }) + st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) + if err != nil { + t.Fatal(err) + } + secure, err := app.NewManager(memory.New(memory.Options{SimulateDurable: true}), app.NewRandGenerator([]byte("12345678901234567890123456789012")), b2bualineage.New(st), app.ManagerConfig{ + FingerprintKey: []byte("12345678901234567890123456789012"), + StoreDurable: true, + }) + if err != nil { + t.Fatal(err) + } + ex := &coreruntime.Executor{ + Store: st, + SecureSession: secure, + SyntheticLocalPrincipal: true, + Bus: hooks.New(hooks.Config{}), + Rand: routing.NewSeededRng(1), + Backends: map[string]execbackend.Backend{ + backend.ID: be, + }, + } + maxTok := 512 + call := sampleCall() + call.Route.Selector = "openai-codex:gpt-5.4-mini" + call.Options = lipapi.GenerationOptions{MaxOutputTokens: &maxTok} + _, err = ex.Execute(context.Background(), &call) + if err == nil { + t.Fatal("expected error for unsupported max_output_tokens without compat extension") + } + if !strings.Contains(err.Error(), "max_output_tokens") { + t.Fatalf("err = %v", err) + } +} + +// TestOpen_compatDropsUnsupportedGenerationParamsFromClient proves that generation options the +// Codex Responses API does not support are dropped when the compat extension opts in. +func TestOpen_compatDropsUnsupportedGenerationParamsFromClient(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + }) + st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) + if err != nil { + t.Fatal(err) + } + secure, err := app.NewManager(memory.New(memory.Options{SimulateDurable: true}), app.NewRandGenerator([]byte("12345678901234567890123456789012")), b2bualineage.New(st), app.ManagerConfig{ + FingerprintKey: []byte("12345678901234567890123456789012"), + StoreDurable: true, + }) + if err != nil { + t.Fatal(err) + } + ex := &coreruntime.Executor{ + Store: st, + SecureSession: secure, + SyntheticLocalPrincipal: true, + Bus: hooks.New(hooks.Config{}), + Rand: routing.NewSeededRng(1), + Backends: map[string]execbackend.Backend{ + backend.ID: be, + }, + } + maxTok := 512 + temp := 0.2 + topP := 0.9 + call := sampleCall() + call.Route.Selector = "openai-codex:gpt-5.4-mini" + call.Options = lipapi.GenerationOptions{ + MaxOutputTokens: &maxTok, + Temperature: &temp, + TopP: &topP, + } + call.Extensions = map[string]json.RawMessage{ + backend.ExtIgnoreUnsupportedGenParams: json.RawMessage(`true`), + } + es, err := ex.Execute(context.Background(), &call) + if err != nil { + t.Fatal(err) + } + if _, err := lipapi.Collect(context.Background(), es); err != nil { + t.Fatal(err) + } + body := srv.LatestRequest().Body + for _, key := range []string{"max_output_tokens", "max_tokens", "temperature", "top_p"} { + if _, ok := body[key]; ok { + t.Fatalf("upstream body must not include %s: %#v", key, body[key]) + } + } +} + +// TestOpen_stripsOpenAIProviderModelPrefix proves that a client using the OpenCode +// provider/model convention (e.g. "openai/gpt-5.4-mini") has the "openai/" prefix stripped +// before the model reaches the Codex upstream, which rejects org-prefixed model names. +func TestOpen_stripsOpenAIProviderModelPrefix(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + }) + st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) + if err != nil { + t.Fatal(err) + } + secure, err := app.NewManager(memory.New(memory.Options{SimulateDurable: true}), app.NewRandGenerator([]byte("12345678901234567890123456789012")), b2bualineage.New(st), app.ManagerConfig{ + FingerprintKey: []byte("12345678901234567890123456789012"), + StoreDurable: true, + }) + if err != nil { + t.Fatal(err) + } + ex := &coreruntime.Executor{ + Store: st, + SecureSession: secure, + SyntheticLocalPrincipal: true, + Bus: hooks.New(hooks.Config{}), + Rand: routing.NewSeededRng(1), + Backends: map[string]execbackend.Backend{ + backend.ID: be, + }, + } + call := sampleCall() + call.Route.Selector = "openai-codex:openai/gpt-5.4-mini?reasoning_effort=low" + es, err := ex.Execute(context.Background(), &call) + if err != nil { + t.Fatal(err) + } + if _, err := lipapi.Collect(context.Background(), es); err != nil { + t.Fatal(err) + } + body := srv.LatestRequest().Body + if got, _ := body["model"].(string); got != "gpt-5.4-mini" { + t.Fatalf("upstream model = %q, want %q (openai/ prefix must be stripped): %#v", got, "gpt-5.4-mini", body) + } +} + +// TestOpen_looseToolSchemaSentStrictFalse proves at the executor level that a +// client tool whose JSON schema is not Codex-strict-compatible (missing +// additionalProperties:false, as OpenCode's apply_patch is) is forwarded with +// strict:false so the upstream Codex Responses API does not reject the request. +func TestOpen_normalizesToolSchemaAdditionalPropertiesForStrict(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + }) + st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) + if err != nil { + t.Fatal(err) + } + secure, err := app.NewManager(memory.New(memory.Options{SimulateDurable: true}), app.NewRandGenerator([]byte("12345678901234567890123456789012")), b2bualineage.New(st), app.ManagerConfig{ + FingerprintKey: []byte("12345678901234567890123456789012"), + StoreDurable: true, + }) + if err != nil { + t.Fatal(err) + } + ex := &coreruntime.Executor{ + Store: st, + SecureSession: secure, + SyntheticLocalPrincipal: true, + Bus: hooks.New(hooks.Config{}), + Rand: routing.NewSeededRng(1), + Backends: map[string]execbackend.Backend{ + backend.ID: be, + }, + } + call := sampleCall() + call.Route.Selector = "openai-codex:gpt-5.4-mini" + call.Tools = []lipapi.ToolDef{{ + Name: "apply_patch", + Parameters: json.RawMessage(`{"type":"object","properties":{"patch":{"type":"string"}},"required":["patch"]}`), + }} + es, err := ex.Execute(context.Background(), &call) + if err != nil { + t.Fatal(err) + } + if _, err := lipapi.Collect(context.Background(), es); err != nil { + t.Fatal(err) + } + tools, ok := srv.LatestRequest().Body["tools"].([]any) + if !ok || len(tools) != 1 { + t.Fatalf("upstream tools: %#v", srv.LatestRequest().Body["tools"]) + } + tool, ok := tools[0].(map[string]any) + if !ok { + t.Fatalf("upstream tool[0]: %#v", tools[0]) + } + if name, _ := tool["name"].(string); name != "apply_patch" { + t.Fatalf("upstream tool name = %q, want apply_patch: %#v", name, tool) + } + if strict, _ := tool["strict"].(bool); !strict { + t.Fatalf("upstream apply_patch strict=false; normalizable schema should be strict: %#v", tool) + } + params, ok := tool["parameters"].(map[string]any) + if !ok { + t.Fatalf("upstream parameters: %#v", tool["parameters"]) + } + if ap, ok := params["additionalProperties"].(bool); !ok || ap { + t.Fatalf("additionalProperties = %#v, want false: %#v", params["additionalProperties"], params) + } +} + +// TestOpen_chatCompletionsToolCallRoundTrip proves at the executor level that an +// assistant tool call decoded from a Chat Completions request (PartJSON with +// type:"function" and a nested "function" object, plus a matching tool result) +// is translated into Codex input function_call + function_call_output items +// linked by the same call_id, instead of failing with "unsupported part kind". +func TestOpen_chatCompletionsToolCallRoundTrip(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "done"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + }) + st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) + if err != nil { + t.Fatal(err) + } + secure, err := app.NewManager(memory.New(memory.Options{SimulateDurable: true}), app.NewRandGenerator([]byte("12345678901234567890123456789012")), b2bualineage.New(st), app.ManagerConfig{ + FingerprintKey: []byte("12345678901234567890123456789012"), + StoreDurable: true, + }) + if err != nil { + t.Fatal(err) + } + ex := &coreruntime.Executor{ + Store: st, + SecureSession: secure, + SyntheticLocalPrincipal: true, + Bus: hooks.New(hooks.Config{}), + Rand: routing.NewSeededRng(1), + Backends: map[string]execbackend.Backend{ + backend.ID: be, + }, + } + call := sampleCall() + call.Route.Selector = "openai-codex:gpt-5.4-mini" + call.Messages = []lipapi.Message{ + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("run it")}}, + {Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: []byte(`{"id":"call_abc","type":"function","function":{"name":"bash","arguments":"{\"command\":\"echo pong\"}"}}`), + }}}, + {Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_abc", + Content: []byte("pong\n"), + }}}, + } + call.Tools = []lipapi.ToolDef{{Name: "bash", Parameters: json.RawMessage(`{"type":"object"}`)}} + es, err := ex.Execute(context.Background(), &call) + if err != nil { + t.Fatalf("execute: %v", err) + } + if _, err := lipapi.Collect(context.Background(), es); err != nil { + t.Fatalf("collect: %v", err) + } + input, ok := srv.LatestRequest().Body["input"].([]any) + if !ok { + t.Fatalf("upstream input missing: %#v", srv.LatestRequest().Body["input"]) + } + var sawCall, sawOutput bool + for _, it := range input { + m, ok := it.(map[string]any) + if !ok { + continue + } + if m["type"] == "function_call" { + if m["call_id"] == "call_abc" && m["name"] == "bash" { + sawCall = true + } + } + if m["type"] == "function_call_output" && m["call_id"] == "call_abc" { + sawOutput = true + } + } + if !sawCall { + t.Fatalf("upstream input missing function_call(call_abc,bash): %#v", input) + } + if !sawOutput { + t.Fatalf("upstream input missing function_call_output(call_abc): %#v", input) + } +} + +// TestOpen_routeSelectorRoutesArbitraryCodexModels proves manual routing to arbitrary +// openai-codex models via the full route selector "openai-codex:?reasoning_effort=low". +// For each example model the request reaching the Codex backend carries that exact model and +// the reasoning_effort URI param converted into the payload reasoning.effort data structure. +func TestOpen_routeSelectorRoutesArbitraryCodexModels(t *testing.T) { + t.Parallel() + models := []string{"gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex-spark"} + for _, model := range models { + t.Run(model, func(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + }) + st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) + if err != nil { + t.Fatal(err) + } + secure, err := app.NewManager(memory.New(memory.Options{SimulateDurable: true}), app.NewRandGenerator([]byte("12345678901234567890123456789012")), b2bualineage.New(st), app.ManagerConfig{ + FingerprintKey: []byte("12345678901234567890123456789012"), + StoreDurable: true, + }) + if err != nil { + t.Fatal(err) + } + ex := &coreruntime.Executor{ + Store: st, + SecureSession: secure, + SyntheticLocalPrincipal: true, + Bus: hooks.New(hooks.Config{}), + Rand: routing.NewSeededRng(1), + Backends: map[string]execbackend.Backend{ + backend.ID: be, + }, + } + call := sampleCall() + call.Route.Selector = "openai-codex:" + model + "?reasoning_effort=low" + es, err := ex.Execute(context.Background(), &call) + if err != nil { + t.Fatal(err) + } + if _, err := lipapi.Collect(context.Background(), es); err != nil { + t.Fatal(err) + } + body := srv.LatestRequest().Body + if got, _ := body["model"].(string); got != model { + t.Fatalf("payload model %q, want %q (body: %#v)", got, model, body) + } + reasoning, ok := body["reasoning"].(map[string]any) + if !ok || reasoning["effort"] != "low" { + t.Fatalf("reasoning payload: %#v", body["reasoning"]) + } + }) + } +} + +func TestOpen_httpsDoesNotSendContinuation(t *testing.T) { + t.Parallel() + var mu sync.Mutex + var bodies []map[string]any + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost || !strings.HasSuffix(r.URL.Path, "/responses") { + http.NotFound(w, r) + return + } + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "bad json", http.StatusBadRequest) + return + } + mu.Lock() + bodies = append(bodies, body) + n := len(bodies) + mu.Unlock() + + w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") + if n == 1 { + _, _ = io.WriteString(w, `data: {"type":"response.created","response":{"id":"resp_1"}}`+"\n\n") + _, _ = io.WriteString(w, `data: {"type":"response.output_item.done","item":{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}}`+"\n\n") + _, _ = io.WriteString(w, `data: {"type":"response.completed","response":{"id":"resp_1","output":[{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`+"\n\n") + return + } + _, _ = io.WriteString(w, `data: {"type":"response.created","response":{"id":"resp_2"}}`+"\n\n") + _, _ = io.WriteString(w, `data: {"type":"response.completed","response":{"id":"resp_2","output":[{"type":"message","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`+"\n\n") + })) + t.Cleanup(srv.Close) + + be := backend.New(backend.Config{ + BaseURL: srv.URL + "/backend-api/codex", + AccessToken: "tok", + HTTPClient: srv.Client(), + Transport: backend.TransportHTTPS, + }) + call := lipapi.Call{ + ID: "call_aaaaaaaaaaaaaaaa", + Session: lipapi.SessionRef{ClientSessionID: "sess-https-continuation"}, + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("inspect")}, + }}, + Tools: []lipapi.ToolDef{{Name: "read"}}, + } + es, err := be.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + call.Messages = append(call.Messages, + lipapi.Message{Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"fc_1","call_id":"call_fc_1","type":"function_call","name":"read","arguments":"{\"filePath\":\"a.go\"}"}`), + }}}, + lipapi.Message{Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_fc_1", + Content: json.RawMessage(`{"content":"package main"}`), + }}}, + ) + es, err = be.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + mu.Lock() + defer mu.Unlock() + if len(bodies) != 2 { + t.Fatalf("requests = %d, want 2", len(bodies)) + } + if _, ok := bodies[0]["previous_response_id"]; ok { + t.Fatalf("first request unexpectedly had previous_response_id: %#v", bodies[0]) + } + if _, ok := bodies[1]["previous_response_id"]; ok { + t.Fatalf("HTTPS request must not send previous_response_id: %#v", bodies[1]) + } + input, ok := bodies[1]["input"].([]any) + if !ok { + t.Fatalf("second input = %#v", bodies[1]["input"]) + } + if len(input) <= 1 { + t.Fatalf("HTTPS second request should be full replay, len=%d body=%#v", len(input), bodies[1]) + } +} + +func TestOpen_websocketContinuationSendsDeltaAndPreservesTools(t *testing.T) { + t.Parallel() + var mu sync.Mutex + var bodies []map[string]any + var conns []*gorillawebsocket.Conn + connCount := 0 + upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !gorillawebsocket.IsWebSocketUpgrade(r) { + http.NotFound(w, r) + return + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + mu.Lock() + connCount++ + conns = append(conns, conn) + mu.Unlock() + defer func() { _ = conn.Close() }() + for { + _, data, err := conn.ReadMessage() + if err != nil { + return + } + var body map[string]any + if err := json.Unmarshal(data, &body); err != nil { + return + } + mu.Lock() + bodies = append(bodies, body) + n := len(bodies) + mu.Unlock() + if n == 1 { + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_1"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.output_item.done","item":{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.completed","response":{"id":"resp_1","output":[{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`)) + continue + } + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_2"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.completed","response":{"id":"resp_2","output":[{"type":"message","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`)) + } + })) + t.Cleanup(func() { + mu.Lock() + for _, conn := range conns { + _ = conn.Close() + } + mu.Unlock() + srv.CloseClientConnections() + srv.Close() + }) + + be := backend.New(backend.Config{ + BaseURL: srv.URL + "/backend-api/codex", + AccessToken: "tok", + HTTPClient: srv.Client(), + Transport: backend.TransportWebSocket, + ExperimentalWebSocket: true, + }) + call := lipapi.Call{ + ID: "call_bbbbbbbbbbbbbbbb", + Session: lipapi.SessionRef{ClientSessionID: "sess-ws-continuation"}, + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("inspect")}, + }}, + Tools: []lipapi.ToolDef{{Name: "read"}}, + } + es, err := be.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + call.Messages = append(call.Messages, + lipapi.Message{Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"fc_1","call_id":"call_fc_1","type":"function_call","name":"read","arguments":"{\"filePath\":\"a.go\"}"}`), + }}}, + lipapi.Message{Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_fc_1", + Content: json.RawMessage(`{"content":"package main"}`), + }}}, + ) + es, err = be.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + mu.Lock() + defer mu.Unlock() + if connCount != 1 { + t.Fatalf("websocket connections = %d, want one reused connection", connCount) + } + if len(bodies) != 2 { + t.Fatalf("requests = %d, want 2", len(bodies)) + } + if got, _ := bodies[1]["previous_response_id"].(string); got != "resp_1" { + t.Fatalf("second previous_response_id = %q, body=%#v", got, bodies[1]) + } + input, ok := bodies[1]["input"].([]any) + if !ok { + t.Fatalf("second input = %#v", bodies[1]["input"]) + } + if len(input) != 1 { + t.Fatalf("second input len = %d, body=%#v", len(input), bodies[1]) + } + item, ok := input[0].(map[string]any) + if !ok || item["type"] != "function_call_output" || item["call_id"] != "call_fc_1" { + t.Fatalf("second delta input = %#v", input) + } + tools, ok := bodies[1]["tools"].([]any) + if !ok || len(tools) != 1 { + t.Fatalf("WS continuation must preserve tools, got %#v", bodies[1]["tools"]) + } +} + +func TestOpen_websocketStateIsIsolatedPerBackendInstance(t *testing.T) { + t.Parallel() + var mu sync.Mutex + var bodies []map[string]any + var conns []*gorillawebsocket.Conn + connCount := 0 + upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !gorillawebsocket.IsWebSocketUpgrade(r) { + http.NotFound(w, r) + return + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + mu.Lock() + connCount++ + conns = append(conns, conn) + mu.Unlock() + defer func() { _ = conn.Close() }() + for { + _, data, err := conn.ReadMessage() + if err != nil { + return + } + var body map[string]any + if err := json.Unmarshal(data, &body); err != nil { + return + } + mu.Lock() + bodies = append(bodies, body) + n := len(bodies) + mu.Unlock() + switch n { + case 1: + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_1"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.output_item.done","item":{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.completed","response":{"id":"resp_1","output":[{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`)) + case 2: + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_2"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.completed","response":{"id":"resp_2","output":[{"type":"message","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`)) + } + } + })) + t.Cleanup(func() { + mu.Lock() + for _, conn := range conns { + _ = conn.Close() + } + mu.Unlock() + srv.CloseClientConnections() + srv.Close() + }) + + cfg := backend.Config{ + BaseURL: srv.URL + "/backend-api/codex", + AccessToken: "tok", + HTTPClient: srv.Client(), + Transport: backend.TransportWebSocket, + ExperimentalWebSocket: true, + } + be1 := backend.New(cfg) + be2 := backend.New(cfg) + call := lipapi.Call{ + ID: "call_eeeeeeeeeeeeeeee", + Session: lipapi.SessionRef{ClientSessionID: "sess-ws-instance-isolation"}, + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("inspect")}, + }}, + Tools: []lipapi.ToolDef{{Name: "read"}}, + } + es, err := be1.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + call.Messages = append(call.Messages, + lipapi.Message{Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"fc_1","call_id":"call_fc_1","type":"function_call","name":"read","arguments":"{\"filePath\":\"a.go\"}"}`), + }}}, + lipapi.Message{Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_fc_1", + Content: json.RawMessage(`{"content":"package main"}`), + }}}, + ) + es, err = be2.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + mu.Lock() + defer mu.Unlock() + if connCount != 2 { + t.Fatalf("websocket connections = %d, want one connection per backend instance", connCount) + } + if len(bodies) != 2 { + t.Fatalf("requests = %d, want 2", len(bodies)) + } + if _, ok := bodies[1]["previous_response_id"]; ok { + t.Fatalf("second backend instance must not reuse first instance continuation: %#v", bodies[1]) + } + input, ok := bodies[1]["input"].([]any) + if !ok { + t.Fatalf("second input = %#v", bodies[1]["input"]) + } + if len(input) <= 1 { + t.Fatalf("second backend instance should send full history, len=%d body=%#v", len(input), bodies[1]) + } +} + +func TestOpen_websocketContinuationInvalidationRetriesFullPayload(t *testing.T) { + t.Parallel() + var mu sync.Mutex + var bodies []map[string]any + var conns []*gorillawebsocket.Conn + connCount := 0 + upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !gorillawebsocket.IsWebSocketUpgrade(r) { + http.NotFound(w, r) + return + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + mu.Lock() + connCount++ + conns = append(conns, conn) + mu.Unlock() + defer func() { _ = conn.Close() }() + for { + _, data, err := conn.ReadMessage() + if err != nil { + return + } + var body map[string]any + if err := json.Unmarshal(data, &body); err != nil { + return + } + mu.Lock() + bodies = append(bodies, body) + n := len(bodies) + mu.Unlock() + switch n { + case 1: + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_1"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.output_item.done","item":{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.completed","response":{"id":"resp_1","output":[{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`)) + case 2: + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_stale"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"previous response not found"}}`)) + return + case 3: + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_2"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.completed","response":{"id":"resp_2","output":[{"type":"message","content":[{"type":"output_text","text":"ok"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`)) + } + } + })) + t.Cleanup(func() { + mu.Lock() + for _, conn := range conns { + _ = conn.Close() + } + mu.Unlock() + srv.CloseClientConnections() + srv.Close() + }) + + be := backend.New(backend.Config{ + BaseURL: srv.URL + "/backend-api/codex", + AccessToken: "tok", + HTTPClient: srv.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + }) + call := lipapi.Call{ + ID: "call_dddddddddddddddd", + Session: lipapi.SessionRef{ClientSessionID: "sess-ws-continuation-invalid"}, + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("inspect")}, + }}, + Tools: []lipapi.ToolDef{{Name: "read"}}, + } + es, err := be.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + call.Messages = append(call.Messages, + lipapi.Message{Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"fc_1","call_id":"call_fc_1","type":"function_call","name":"read","arguments":"{\"filePath\":\"a.go\"}"}`), + }}}, + lipapi.Message{Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_fc_1", + Content: json.RawMessage(`{"content":"package main"}`), + }}}, + ) + es, err = be.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + mu.Lock() + defer mu.Unlock() + if connCount != 2 { + t.Fatalf("websocket connections = %d, want stale continuation retry on a fresh connection", connCount) + } + if len(bodies) != 3 { + t.Fatalf("requests = %d, want initial + stale continuation + full retry", len(bodies)) + } + if got, _ := bodies[1]["previous_response_id"].(string); got != "resp_1" { + t.Fatalf("stale continuation previous_response_id = %q, body=%#v", got, bodies[1]) + } + if _, ok := bodies[2]["previous_response_id"]; ok { + t.Fatalf("full retry must drop previous_response_id after invalidation: %#v", bodies[2]) + } + input, ok := bodies[2]["input"].([]any) + if !ok { + t.Fatalf("full retry input = %#v", bodies[2]["input"]) + } + if len(input) <= 1 { + t.Fatalf("full retry should replay history, len=%d body=%#v", len(input), bodies[2]) + } +} + +func TestOpen_websocketContinuationReturnsAfterFirstEvent(t *testing.T) { + t.Parallel() + // Regression for a live OpenCode stall: strict WebSocket mode once waited for + // committed output during continuation, so a fast response.created frame could + // sit behind a blocked tool-result read until the frontend timed out. Strict WS + // must return after the first canonical event and prepend it to the stream; only + // auto mode waits longer so it can still fall back to HTTPS before commitment. + var mu sync.Mutex + var bodies []map[string]any + var conns []*gorillawebsocket.Conn + upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !gorillawebsocket.IsWebSocketUpgrade(r) { + http.NotFound(w, r) + return + } + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + mu.Lock() + conns = append(conns, conn) + mu.Unlock() + defer func() { _ = conn.Close() }() + for { + _, data, err := conn.ReadMessage() + if err != nil { + return + } + var body map[string]any + if err := json.Unmarshal(data, &body); err != nil { + return + } + mu.Lock() + bodies = append(bodies, body) + n := len(bodies) + mu.Unlock() + if n == 1 { + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_1"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.output_item.done","item":{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}}`)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.completed","response":{"id":"resp_1","output":[{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"read","arguments":"{\"filePath\":\"a.go\"}"}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}}`)) + continue + } + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_2"}}`)) + _, _, _ = conn.ReadMessage() + return + } + })) + t.Cleanup(func() { + mu.Lock() + for _, conn := range conns { + _ = conn.Close() + } + mu.Unlock() + srv.CloseClientConnections() + srv.Close() + }) + + be := backend.New(backend.Config{ + BaseURL: srv.URL + "/backend-api/codex", + AccessToken: "tok", + HTTPClient: srv.Client(), + Transport: backend.TransportWebSocket, + ExperimentalWebSocket: true, + }) + call := lipapi.Call{ + ID: "call_cccccccccccccccc", + Session: lipapi.SessionRef{ClientSessionID: "sess-ws-first-event"}, + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("inspect")}, + }}, + Tools: []lipapi.ToolDef{{Name: "read"}}, + } + es, err := be.Open(context.Background(), call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + call.Messages = append(call.Messages, + lipapi.Message{Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"fc_1","call_id":"call_fc_1","type":"function_call","name":"read","arguments":"{\"filePath\":\"a.go\"}"}`), + }}}, + lipapi.Message{Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_fc_1", + Content: json.RawMessage(`{"content":"package main"}`), + }}}, + ) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + es, err = be.Open(ctx, call, routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.4-mini"}}) + if err != nil { + t.Fatalf("strict websocket continuation should return after response.created, got %v", err) + } + defer func() { _ = es.Close() }() + if ev, err := es.Recv(context.Background()); err != nil || ev.Kind != lipapi.EventResponseStarted { + t.Fatalf("first continuation event = (%v, %v), want response_started", ev.Kind, err) + } + + mu.Lock() + defer mu.Unlock() + if len(bodies) != 2 { + t.Fatalf("requests = %d, want 2", len(bodies)) + } + if got, _ := bodies[1]["previous_response_id"].(string); got != "resp_1" { + t.Fatalf("second previous_response_id = %q, body=%#v", got, bodies[1]) + } +} + func TestResolveCaps_returnsCodexBackendCaps(t *testing.T) { t.Parallel() be := backend.New(backend.Config{BaseURL: "http://127.0.0.1", AccessToken: "tok"}) caps := be.ResolveCaps(context.Background(), sampleCall(), routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) for _, cap := range []lipapi.Capability{ lipapi.CapabilityVision, @@ -285,15 +1229,13 @@ func TestModelInventory_builtinWhenNoneConfigured(t *testing.T) { if err != nil { t.Fatal(err) } - found := false + got := make([]string, 0, len(snap.Models)) for _, m := range snap.Models { - if m.NativeID == "gpt-5.3-codex" { - found = true - break - } + got = append(got, m.NativeID) } - if !found { - t.Fatalf("models: %+v", snap.Models) + want := []string{"gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.3-codex-spark"} + if !slices.Equal(got, want) { + t.Fatalf("builtin codex native IDs = %#v, want exactly %#v", got, want) } } diff --git a/internal/plugins/backends/openaicodex/prompt_cache_http_test.go b/internal/plugins/backends/openaicodex/prompt_cache_http_test.go index 73273287..7705f976 100644 --- a/internal/plugins/backends/openaicodex/prompt_cache_http_test.go +++ b/internal/plugins/backends/openaicodex/prompt_cache_http_test.go @@ -33,7 +33,7 @@ func TestOpen_payloadIncludesPromptCacheKeyFromSession(t *testing.T) { }}, } es, err := be.Open(context.Background(), call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) @@ -46,6 +46,48 @@ func TestOpen_payloadIncludesPromptCacheKeyFromSession(t *testing.T) { } } +func TestOpen_payloadUsesAuthoritativeSessionBeforeClientHint(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + }) + call := lipapi.Call{ + ID: "call-fallback", + Session: lipapi.SessionRef{ + ClientSessionID: "client-controlled-session", + AuthoritativeSessionID: "proxy-authoritative-session", + }, + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hello")}, + }}, + } + es, err := be.Open(context.Background(), call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + + got := srv.LatestRequest().Body["prompt_cache_key"] + if got != "proxy-authoritative-session" { + t.Fatalf("prompt_cache_key: %#v", got) + } + if got := srv.LatestRequest().ConversationID; got != "proxy-authoritative-session" { + t.Fatalf("conversation_id: %q", got) + } + if got := srv.LatestRequest().SessionID; got != "proxy-authoritative-session" { + t.Fatalf("session_id: %q", got) + } +} + func TestOpen_payloadPromptCacheKeyFallsBackToCallID(t *testing.T) { t.Parallel() srv := refbackend.New(refbackend.Config{Token: "sk-codex"}) @@ -65,7 +107,7 @@ func TestOpen_payloadPromptCacheKeyFallsBackToCallID(t *testing.T) { }}, } es, err := be.Open(context.Background(), call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }) if err != nil { t.Fatal(err) diff --git a/internal/plugins/backends/openaicodex/stream.go b/internal/plugins/backends/openaicodex/stream.go index d2887d06..8088e485 100644 --- a/internal/plugins/backends/openaicodex/stream.go +++ b/internal/plugins/backends/openaicodex/stream.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "strings" "sync" @@ -15,57 +16,30 @@ import ( "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" ) -type codexStream struct { - mu sync.Mutex - body io.ReadCloser - scanner *bufio.Scanner - pending stream.PendingEventQueue - mapper *openairesponsestream.Mapper - closed bool -} - -func newCodexStream(body io.ReadCloser, maxPending int) *codexStream { - sc := bufio.NewScanner(body) - sc.Buffer(make([]byte, 0, 64*1024), 4<<20) - st := &codexStream{ - body: body, - scanner: sc, - pending: stream.NewPendingEventQueue(maxPending), - } - st.mapper = openairesponsestream.New(&st.pending) - return st -} - -func (s *codexStream) Recv(ctx context.Context) (lipapi.Event, error) { - pump := stream.EventPump[string]{ - Lock: &s.mu, - Pending: &s.pending, - IsClosed: func() bool { return s.closed }, - Read: s.readData, - Handle: s.handleData, - } - return pump.Recv(ctx) +// codexEventMapper holds the canonical-event mapping state shared by SSE and +// WebSocket transports. It is not concurrency-safe; callers must serialize +// handleData calls (the EventPump does this under its lock). +type codexEventMapper struct { + pending stream.PendingEventQueue + mapper *openairesponsestream.Mapper + responseID string + outputItems []inputItem + toolCallIDs map[string]string + provisional map[string]bool + terminal bool } -func (s *codexStream) readData() (string, bool, error) { - for s.scanner.Scan() { - line := strings.TrimSpace(s.scanner.Text()) - if !strings.HasPrefix(line, "data: ") { - continue - } - data := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) - if data == "" || data == "[DONE]" { - continue - } - return data, true, nil - } - if err := s.scanner.Err(); err != nil { - return "", false, fmt.Errorf("%s: read stream: %w", ID, err) - } - return "", false, nil +func newCodexEventMapper(maxPending int) *codexEventMapper { + m := &codexEventMapper{ + pending: stream.NewPendingEventQueue(maxPending), + toolCallIDs: make(map[string]string), + provisional: make(map[string]bool), + } + m.mapper = openairesponsestream.New(&m.pending) + return m } -func (s *codexStream) handleData(data string) error { +func (m *codexEventMapper) handleData(data string) error { var base struct { Type string `json:"type"` } @@ -74,37 +48,54 @@ func (s *codexStream) handleData(data string) error { } switch base.Type { case "response.created": - return s.handleResponseCreated(data) + return m.handleResponseCreated(data) case "response.output_text.delta": - return s.handleOutputTextDelta(data) + return m.handleOutputTextDelta(data) case "response.completed": - return s.handleResponseCompleted(data) + return m.handleResponseCompleted(data) case "error": - return s.handleStreamError(data) + return m.handleStreamError(data) case "response.output_item.added": - return s.handleOutputItemAdded(data) + return m.handleOutputItemAdded(data) case "response.function_call_arguments.delta": - return s.handleFunctionCallArgumentsDelta(data) + return m.handleFunctionCallArgumentsDelta(data) case "response.function_call_arguments.done": - return s.handleFunctionCallArgumentsDone(data) + return m.handleFunctionCallArgumentsDone(data) case "response.output_item.done": - return s.handleOutputItemDone(data) + return m.handleOutputItemDone(data) default: return nil } } -func (s *codexStream) handleOutputTextDelta(data string) error { +func (m *codexEventMapper) handleOutputTextDelta(data string) error { var ev struct { Delta string `json:"delta"` } if err := json.Unmarshal([]byte(data), &ev); err != nil { return fmt.Errorf("%s: malformed stream event: %w", ID, err) } - return s.mapper.OutputTextDelta(ev.Delta) + if looksLikeToolProtocolText(ev.Delta) { + return m.mapper.StreamError("tool_protocol_text_leak", "upstream emitted tool-call protocol as text", "upstream emitted tool-call protocol as text") + } + return m.mapper.OutputTextDelta(ev.Delta) +} + +func looksLikeToolProtocolText(delta string) bool { + text := strings.TrimSpace(delta) + if text == "" { + return false + } + if strings.Contains(text, "to=functions.") || strings.Contains(text, "to=functions_") { + return true + } + if strings.HasPrefix(text, "{") && (strings.Contains(text, `"filePath"`) || strings.Contains(text, `"offset"`) || strings.Contains(text, `"limit"`)) { + return true + } + return false } -func (s *codexStream) handleResponseCreated(data string) error { +func (m *codexEventMapper) handleResponseCreated(data string) error { var ev struct { Response struct { ID string `json:"id"` @@ -113,59 +104,77 @@ func (s *codexStream) handleResponseCreated(data string) error { if err := json.Unmarshal([]byte(data), &ev); err != nil { return fmt.Errorf("%s: malformed stream event: %w", ID, err) } - return s.mapper.ResponseCreated() + m.responseID = strings.TrimSpace(ev.Response.ID) + return m.mapper.ResponseCreated() } -func (s *codexStream) handleResponseCompleted(data string) error { +func (m *codexEventMapper) handleResponseCompleted(data string) error { var ev struct { - Response json.RawMessage `json:"response"` + Response completedResponse `json:"response"` } if err := json.Unmarshal([]byte(data), &ev); err != nil { return fmt.Errorf("%s: malformed stream event: %w", ID, err) } - if err := s.mapper.BeginCompleted(); err != nil { + if err := m.mapper.BeginCompleted(); err != nil { return err } - if len(ev.Response) > 0 { - if !s.mapper.SawTextDelta() { - if text := outputTextFromCompleted(ev.Response); text != "" { - if err := s.mapper.CompletedTextFallback(text); err != nil { - return err - } - } - } - for _, fc := range functionCallsFromCompleted(ev.Response) { - if err := s.mapper.EmitCompletedToolCall( - openairesponsestream.ToolCallID(fc.ID, fc.CallID), - fc.Name, - fc.Arguments, - ); err != nil { + if id := strings.TrimSpace(ev.Response.ID); id != "" { + m.responseID = id + } + if !m.mapper.SawTextDelta() { + if text := ev.Response.outputText(); text != "" { + if err := m.mapper.CompletedTextFallback(text); err != nil { return err } } - if usage := usageFromCompleted(ev.Response); usage != nil { - if err := s.mapper.PushUsage(usage); err != nil { - return err - } + } + for _, item := range ev.Response.Output { + if item.Type != "function_call" { + continue + } + if err := m.mapper.EmitCompletedToolCall( + codexCanonicalToolCallID(item.ID, item.CallID), + item.Name, + item.Arguments, + ); err != nil { + return err } } - return s.mapper.ResponseFinished() + if usage := ev.Response.usageEvent(); usage != nil { + if err := m.mapper.PushUsage(usage); err != nil { + return err + } + } + if err := m.mapper.ResponseFinished(); err != nil { + return err + } + m.terminal = true + return nil } -func outputTextFromCompleted(raw json.RawMessage) string { - var resp struct { - Output []struct { - Content []struct { - Type string `json:"type"` - Text string `json:"text"` - } `json:"content"` - } `json:"output"` - } - if err := json.Unmarshal(raw, &resp); err != nil { - return "" - } +type completedResponse struct { + ID string `json:"id"` + Output []struct { + Type string `json:"type"` + ID string `json:"id"` + CallID string `json:"call_id"` + Name string `json:"name"` + Arguments string `json:"arguments"` + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + } `json:"output"` + Usage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + TotalTokens int64 `json:"total_tokens"` + } `json:"usage"` +} + +func (r completedResponse) outputText() string { var b strings.Builder - for _, item := range resp.Output { + for _, item := range r.Output { for _, c := range item.Content { if c.Type == "output_text" { b.WriteString(c.Text) @@ -175,42 +184,25 @@ func outputTextFromCompleted(raw json.RawMessage) string { return b.String() } -type completedFunctionCall struct { - ID string - CallID string - Name string - Arguments string -} - -func functionCallsFromCompleted(raw json.RawMessage) []completedFunctionCall { - var resp struct { - Output []struct { - Type string `json:"type"` - ID string `json:"id"` - CallID string `json:"call_id"` - Name string `json:"name"` - Arguments string `json:"arguments"` - } `json:"output"` - } - if err := json.Unmarshal(raw, &resp); err != nil { +func (r completedResponse) usageEvent() *lipapi.Event { + u := r.Usage + if u.InputTokens == 0 && u.OutputTokens == 0 && u.TotalTokens == 0 { return nil } - out := make([]completedFunctionCall, 0) - for _, item := range resp.Output { - if item.Type != "function_call" { - continue - } - out = append(out, completedFunctionCall{ - ID: item.ID, - CallID: item.CallID, - Name: item.Name, - Arguments: item.Arguments, - }) - } - return out + return &lipapi.Event{ + Kind: lipapi.EventUsageDelta, + InputTokens: safecast.IntFromInt64Clamp(u.InputTokens), + OutputTokens: safecast.IntFromInt64Clamp(u.OutputTokens), + TotalTokens: safecast.IntFromInt64Clamp(u.TotalTokens), + Accounting: lipapi.UsageAccountingMetadata{ + Plane: lipapi.UsagePlaneProviderBillable, + Source: lipapi.UsageSourceProviderReported, + Authority: lipapi.UsageAuthorityAuthoritative, + }, + } } -func (s *codexStream) handleOutputItemDone(data string) error { +func (m *codexEventMapper) handleOutputItemDone(data string) error { var ev struct { Item struct { Type string `json:"type"` @@ -226,16 +218,51 @@ func (s *codexStream) handleOutputItemDone(data string) error { if ev.Item.Type != "function_call" { return nil } - return s.mapper.FinishToolCallArguments( - openairesponsestream.ToolCallID(ev.Item.ID, ev.Item.CallID), + m.rememberToolCallID(ev.Item.ID, ev.Item.CallID) + m.remapProvisionalToolCall(ev.Item.ID, ev.Item.CallID) + if item, ok := outputFunctionCallInputItem(ev.Item.Type, ev.Item.ID, ev.Item.CallID, ev.Item.Name, ev.Item.Arguments); ok { + m.outputItems = append(m.outputItems, item) + } + return m.mapper.FinishToolCallArguments( + codexCanonicalToolCallID(ev.Item.ID, ev.Item.CallID), ev.Item.Name, ev.Item.Arguments, ) } -func (s *codexStream) handleStreamError(data string) error { +func outputFunctionCallInputItem(itemType, id, callID, name, arguments string) (functionCallItem, bool) { + if itemType != "function_call" { + return functionCallItem{}, false + } + hadCallID := strings.TrimSpace(callID) != "" + callID = strings.TrimSpace(callID) + id = strings.TrimSpace(id) + if callID == "" { + callID = id + } + name = strings.TrimSpace(name) + if callID == "" || name == "" { + return functionCallItem{}, false + } + if strings.TrimSpace(arguments) == "" { + arguments = "{}" + } + item := functionCallItem{ + Type: "function_call", + CallID: callID, + Name: name, + Arguments: arguments, + } + if id != "" && hadCallID { + item.ID = id + } + return item, true +} + +func (m *codexEventMapper) handleStreamError(data string) error { var ev struct { Error *struct { + Code string `json:"code"` Message string `json:"message"` } `json:"error"` } @@ -243,13 +270,18 @@ func (s *codexStream) handleStreamError(data string) error { return fmt.Errorf("%s: malformed stream event: %w", ID, err) } msg := "" + code := "" if ev.Error != nil { + code = ev.Error.Code msg = ev.Error.Message } - return s.mapper.StreamError("", msg, "upstream error") + if debugTurnsEnabled() { + slog.Debug("openaicodex.debug.upstream_error", "code", code, "message", msg) + } + return m.mapper.StreamError(code, msg, "upstream error") } -func (s *codexStream) handleOutputItemAdded(data string) error { +func (m *codexEventMapper) handleOutputItemAdded(data string) error { var ev struct { Item struct { Type string `json:"type"` @@ -264,10 +296,12 @@ func (s *codexStream) handleOutputItemAdded(data string) error { if ev.Item.Type != "function_call" { return nil } - return s.mapper.ToolCallAdded(openairesponsestream.ToolCallID(ev.Item.ID, ev.Item.CallID), ev.Item.Name) + m.rememberToolCallID(ev.Item.ID, ev.Item.CallID) + m.remapProvisionalToolCall(ev.Item.ID, ev.Item.CallID) + return m.mapper.ToolCallAdded(codexCanonicalToolCallID(ev.Item.ID, ev.Item.CallID), ev.Item.Name) } -func (s *codexStream) handleFunctionCallArgumentsDelta(data string) error { +func (m *codexEventMapper) handleFunctionCallArgumentsDelta(data string) error { var ev struct { ItemID string `json:"item_id"` CallID string `json:"call_id"` @@ -276,10 +310,13 @@ func (s *codexStream) handleFunctionCallArgumentsDelta(data string) error { if err := json.Unmarshal([]byte(data), &ev); err != nil { return fmt.Errorf("%s: malformed stream event: %w", ID, err) } - return s.mapper.ToolCallArgsDelta(openairesponsestream.ToolCallID(ev.ItemID, ev.CallID), ev.Delta) + if codexToolDeltaDebugEnabled() { + slog.Debug("openaicodex.tool_args_delta", "item_id", ev.ItemID, "call_id", ev.CallID, "delta", truncateDebug(ev.Delta, 512)) + } + return m.mapper.ToolCallArgsDelta(m.toolCallID(ev.ItemID, ev.CallID), ev.Delta) } -func (s *codexStream) handleFunctionCallArgumentsDone(data string) error { +func (m *codexEventMapper) handleFunctionCallArgumentsDone(data string) error { var ev struct { ItemID string `json:"item_id"` CallID string `json:"call_id"` @@ -289,35 +326,118 @@ func (s *codexStream) handleFunctionCallArgumentsDone(data string) error { if err := json.Unmarshal([]byte(data), &ev); err != nil { return fmt.Errorf("%s: malformed stream event: %w", ID, err) } - return s.mapper.FinishToolCallArguments(openairesponsestream.ToolCallID(ev.ItemID, ev.CallID), ev.Name, ev.Arguments) + if codexToolDebugEnabled() { + slog.Debug("openaicodex.tool_args_done", "item_id", ev.ItemID, "call_id", ev.CallID, "name", ev.Name, "arguments", truncateDebug(ev.Arguments, 512)) + } + return m.mapper.FinishToolCallArguments(m.toolCallID(ev.ItemID, ev.CallID), ev.Name, ev.Arguments) } -func usageFromCompleted(raw json.RawMessage) *lipapi.Event { - var resp struct { - Usage struct { - InputTokens int64 `json:"input_tokens"` - OutputTokens int64 `json:"output_tokens"` - TotalTokens int64 `json:"total_tokens"` - } `json:"usage"` - } - if err := json.Unmarshal(raw, &resp); err != nil { - return nil +func codexCanonicalToolCallID(itemID, callID string) string { + return openairesponsestream.ToolCallID(callID, itemID) +} + +func (m *codexEventMapper) rememberToolCallID(itemID, callID string) { + itemID = strings.TrimSpace(itemID) + callID = strings.TrimSpace(callID) + if itemID == "" || callID == "" { + return + } + m.toolCallIDs[itemID] = callID + // Once the real call_id is known, drop the provisional flag so toolCallID + // stops returning the item-only ID and all subsequent events canonicalize + // onto the call_id. + delete(m.provisional, itemID) +} + +// remapProvisionalToolCall moves any mapper state buffered under the +// provisional item-only ID onto the real call_id once it is learned. Without +// this, argument deltas that arrived before output_item.added stay buffered +// under the item ID while ToolCallAdded targets the call_id, fragmenting one +// logical tool call into two. +func (m *codexEventMapper) remapProvisionalToolCall(itemID, callID string) { + itemID = strings.TrimSpace(itemID) + callID = strings.TrimSpace(callID) + if itemID == "" || callID == "" || callID == itemID { + return + } + m.mapper.RemapToolCallID(itemID, callID) +} + +func (m *codexEventMapper) toolCallID(itemID, callID string) string { + itemID = strings.TrimSpace(itemID) + callID = strings.TrimSpace(callID) + // Prefer a learned call_id over the provisional item-only ID so deltas and + // completion events resolve to the same canonical ID as output_item.added. + if callID == "" { + callID = strings.TrimSpace(m.toolCallIDs[itemID]) } - u := resp.Usage - if u.InputTokens == 0 && u.OutputTokens == 0 && u.TotalTokens == 0 { - return nil + if callID != "" { + return codexCanonicalToolCallID(itemID, callID) } - return &lipapi.Event{ - Kind: lipapi.EventUsageDelta, - InputTokens: safecast.IntFromInt64Clamp(u.InputTokens), - OutputTokens: safecast.IntFromInt64Clamp(u.OutputTokens), - TotalTokens: safecast.IntFromInt64Clamp(u.TotalTokens), - Accounting: lipapi.UsageAccountingMetadata{ - Plane: lipapi.UsagePlaneProviderBillable, - Source: lipapi.UsageSourceProviderReported, - Authority: lipapi.UsageAuthorityAuthoritative, - }, + if itemID != "" && m.provisional[itemID] { + return itemID + } + if callID == "" && itemID != "" { + m.provisional[itemID] = true + } + return codexCanonicalToolCallID(itemID, callID) +} + +func truncateDebug(s string, max int) string { + if len(s) <= max { + return s + } + return s[:max] + "..." +} + +var _ lipapi.ManagedEventStream = (*codexStream)(nil) + +type codexStream struct { + mapper *codexEventMapper + mu sync.Mutex + body io.ReadCloser + scanner *bufio.Scanner + closed bool +} + +func newCodexStream(body io.ReadCloser, maxPending int) *codexStream { + sc := bufio.NewScanner(body) + sc.Buffer(make([]byte, 0, 64*1024), 4<<20) + st := &codexStream{ + mapper: newCodexEventMapper(maxPending), + body: body, + scanner: sc, + } + return st +} + +func (s *codexStream) Recv(ctx context.Context) (lipapi.Event, error) { + pump := stream.EventPump[string]{ + Lock: &s.mu, + Pending: &s.mapper.pending, + IsClosed: func() bool { return s.closed }, + Read: s.readData, + Handle: s.mapper.handleData, } + return pump.Recv(ctx) +} + +func (s *codexStream) readData() (string, bool, error) { + for s.scanner.Scan() { + line := strings.TrimSpace(s.scanner.Text()) + if !strings.HasPrefix(line, "data: ") { + continue + } + data := strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + if data == "" || data == "[DONE]" { + continue + } + return data, true, nil + } + if err := s.scanner.Err(); err != nil { + return "", false, fmt.Errorf("%s: read stream: %w", ID, err) + } + return "", false, nil } func (s *codexStream) Close() error { diff --git a/internal/plugins/backends/openaicodex/stream_internal_test.go b/internal/plugins/backends/openaicodex/stream_internal_test.go index ac8b01fe..5fbcd89b 100644 --- a/internal/plugins/backends/openaicodex/stream_internal_test.go +++ b/internal/plugins/backends/openaicodex/stream_internal_test.go @@ -17,7 +17,7 @@ func testCodexStream() *codexStream { func TestHandleData_malformedJSON_returnsError(t *testing.T) { t.Parallel() s := testCodexStream() - if err := s.handleData("{not json"); err == nil { + if err := s.mapper.handleData("{not json"); err == nil { t.Fatal("expected malformed JSON error") } } @@ -38,11 +38,11 @@ func TestHandleData_responseCreatedAndCompleted_mapsLifecycleAndUsage(t *testing created := `{"type":"response.created","response":{"id":"resp_created"}}` completed := `{"type":"response.completed","response":{"id":"resp_completed","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}}` for _, raw := range []string{created, completed} { - if err := s.handleData(raw); err != nil { + if err := s.mapper.handleData(raw); err != nil { t.Fatalf("handleData: %v", err) } } - events := stream.DrainPending(&s.pending) + events := stream.DrainPending(&s.mapper.pending) want := []lipapi.EventKind{lipapi.EventResponseStarted, lipapi.EventMessageStarted, lipapi.EventUsageDelta, lipapi.EventResponseFinished} if len(events) != len(want) { t.Fatalf("events: %+v", events) @@ -71,11 +71,11 @@ func TestHandleData_completedWithoutUsageDoesNotEstimateUsage(t *testing.T) { delta := `{"type":"response.output_text.delta","delta":"world"}` completed := `{"type":"response.completed","response":{"id":"resp_1","status":"completed","output":[{"type":"message","content":[{"type":"output_text","text":"world"}]}]}}` for _, raw := range []string{delta, completed} { - if err := s.handleData(raw); err != nil { + if err := s.mapper.handleData(raw); err != nil { t.Fatalf("handleData: %v", err) } } - events := stream.DrainPending(&s.pending) + events := stream.DrainPending(&s.mapper.pending) for _, ev := range events { if ev.Kind == lipapi.EventUsageDelta { t.Fatalf("raw stream must not estimate usage: %+v", events) @@ -101,7 +101,7 @@ func TestUsageEstimatingStream_completedWithoutUsage_generatesEstimatedUsageBefo {Kind: lipapi.EventTextDelta, Delta: "world"}, {Kind: lipapi.EventResponseFinished}, }) - s := newUsageEstimatingStream(base, est, call, "gpt-5.3-codex") + s := newUsageEstimatingStream(base, est, call, "gpt-5.3-codex-spark") var events []lipapi.Event for { @@ -164,7 +164,7 @@ func TestUsageEstimatingStream_providerUsageIsNotOverridden(t *testing.T) { providerUsage, {Kind: lipapi.EventResponseFinished}, }) - s := newUsageEstimatingStream(base, est, lipapi.Call{}, "gpt-5.3-codex") + s := newUsageEstimatingStream(base, est, lipapi.Call{}, "gpt-5.3-codex-spark") var usage []lipapi.Event for { @@ -201,15 +201,19 @@ func TestHandleData_toolCallStream_mapsToCanonicalToolEvents(t *testing.T) { `{"type":"response.function_call_arguments.done","sequence_number":3,"item_id":"fc_1","output_index":0,"name":"get_weather","arguments":"{\"city\":\"NYC\"}"}`, } for _, raw := range rawEvents { - if err := s.handleData(raw); err != nil { + if err := s.mapper.handleData(raw); err != nil { t.Fatalf("handleData: %v", err) } } var kinds []lipapi.EventKind var args strings.Builder - for _, ev := range stream.DrainPending(&s.pending) { + var toolID string + for _, ev := range stream.DrainPending(&s.mapper.pending) { kinds = append(kinds, ev.Kind) + if ev.Kind == lipapi.EventToolCallStarted { + toolID = ev.ToolCallID + } if ev.Kind == lipapi.EventToolCallArgsDelta { args.WriteString(ev.Delta) } @@ -220,6 +224,9 @@ func TestHandleData_toolCallStream_mapsToCanonicalToolEvents(t *testing.T) { if got := args.String(); got != `{"city":"NYC"}` { t.Fatalf("combined args: %q", got) } + if toolID != "call_fc_1" { + t.Fatalf("tool call id = %q, want upstream call_id", toolID) + } if kinds[len(kinds)-1] != lipapi.EventToolCallFinished { t.Fatalf("last event: %v", kinds) } @@ -230,12 +237,12 @@ func TestHandleData_completedOnly_emitsFullText(t *testing.T) { s := testCodexStream() completed := `{"type":"response.completed","response":{"id":"resp_1","status":"completed","output":[{"type":"message","content":[{"type":"output_text","text":"done"}]}]}}` - if err := s.handleData(completed); err != nil { + if err := s.mapper.handleData(completed); err != nil { t.Fatal(err) } var texts []string - for _, ev := range stream.DrainPending(&s.pending) { + for _, ev := range stream.DrainPending(&s.mapper.pending) { if ev.Kind == lipapi.EventTextDelta { texts = append(texts, ev.Delta) } @@ -245,18 +252,46 @@ func TestHandleData_completedOnly_emitsFullText(t *testing.T) { } } +func TestHandleData_blocksToolProtocolTextLeak(t *testing.T) { + t.Parallel() + s := testCodexStream() + + raw := `{"type":"response.output_text.delta","delta":"{\"filePath\":\"C:\\\\repo\\\\file.go\",\"offset\":49,\"limit\":120}to=functions.read"}` + if err := s.mapper.handleData(raw); err != nil { + t.Fatal(err) + } + + events := stream.DrainPending(&s.mapper.pending) + if len(events) == 0 { + t.Fatal("expected error event") + } + for _, ev := range events { + if ev.Kind == lipapi.EventTextDelta { + t.Fatalf("tool protocol leaked as text: %+v", ev) + } + } + last := events[len(events)-1] + if last.Kind != lipapi.EventError || last.ErrorCode != "tool_protocol_text_leak" { + t.Fatalf("last event = %+v, want tool protocol leak error", last) + } +} + func TestHandleData_completed_replaysFunctionCalls(t *testing.T) { t.Parallel() s := testCodexStream() completed := `{"type":"response.completed","response":{"id":"resp_1","status":"completed","output":[{"type":"function_call","id":"fc_1","call_id":"call_fc_1","name":"get_weather","arguments":"{\"city\":\"NYC\"}"}]}}` - if err := s.handleData(completed); err != nil { + if err := s.mapper.handleData(completed); err != nil { t.Fatal(err) } var kinds []lipapi.EventKind - for _, ev := range stream.DrainPending(&s.pending) { + var toolID string + for _, ev := range stream.DrainPending(&s.mapper.pending) { kinds = append(kinds, ev.Kind) + if ev.Kind == lipapi.EventToolCallStarted { + toolID = ev.ToolCallID + } } want := []lipapi.EventKind{ lipapi.EventResponseStarted, @@ -274,6 +309,9 @@ func TestHandleData_completed_replaysFunctionCalls(t *testing.T) { t.Fatalf("event[%d] = %v, want %v", i, kinds[i], kind) } } + if toolID != "call_fc_1" { + t.Fatalf("tool call id = %q, want upstream call_id", toolID) + } } func TestHandleData_outputItemDone_emitsCompleteToolCall(t *testing.T) { @@ -281,13 +319,17 @@ func TestHandleData_outputItemDone_emitsCompleteToolCall(t *testing.T) { s := testCodexStream() raw := `{"type":"response.output_item.done","output_index":0,"item":{"type":"function_call","id":"fc_done","call_id":"call_fc_done","name":"get_weather","arguments":"{\"city\":\"NYC\"}"}}` - if err := s.handleData(raw); err != nil { + if err := s.mapper.handleData(raw); err != nil { t.Fatal(err) } var kinds []lipapi.EventKind - for _, ev := range stream.DrainPending(&s.pending) { + var toolID string + for _, ev := range stream.DrainPending(&s.mapper.pending) { kinds = append(kinds, ev.Kind) + if ev.Kind == lipapi.EventToolCallStarted { + toolID = ev.ToolCallID + } } want := []lipapi.EventKind{ lipapi.EventResponseStarted, @@ -304,6 +346,19 @@ func TestHandleData_outputItemDone_emitsCompleteToolCall(t *testing.T) { t.Fatalf("event[%d] = %v, want %v", i, kinds[i], kind) } } + if toolID != "call_fc_done" { + t.Fatalf("tool call id = %q, want upstream call_id", toolID) + } + if len(s.mapper.outputItems) != 1 { + t.Fatalf("output items = %#v", s.mapper.outputItems) + } + item, ok := s.mapper.outputItems[0].(functionCallItem) + if !ok { + t.Fatalf("output item = %#v, want functionCallItem", s.mapper.outputItems[0]) + } + if item.ID != "fc_done" || item.CallID != "call_fc_done" || item.Name != "get_weather" || item.Arguments != `{"city":"NYC"}` { + t.Fatalf("output item = %#v", item) + } } func TestHandleData_toolCallStream_callIDOnDelta(t *testing.T) { @@ -316,13 +371,13 @@ func TestHandleData_toolCallStream_callIDOnDelta(t *testing.T) { `{"type":"response.function_call_arguments.done","sequence_number":2,"call_id":"call_only","output_index":0,"name":"get_weather","arguments":"{\"x\":1}"}`, } for _, raw := range rawEvents { - if err := s.handleData(raw); err != nil { + if err := s.mapper.handleData(raw); err != nil { t.Fatalf("handleData: %v", err) } } var toolID string - for _, ev := range stream.DrainPending(&s.pending) { + for _, ev := range stream.DrainPending(&s.mapper.pending) { if ev.Kind == lipapi.EventToolCallStarted { toolID = ev.ToolCallID } @@ -331,3 +386,53 @@ func TestHandleData_toolCallStream_callIDOnDelta(t *testing.T) { t.Fatalf("tool call id: %q", toolID) } } + +func TestHandleData_toolArgsBeforeAddedWaitForToolName(t *testing.T) { + t.Parallel() + s := testCodexStream() + + rawEvents := []string{ + `{"type":"response.function_call_arguments.delta","sequence_number":1,"item_id":"fc_late","output_index":0,"delta":"{\"filePath\":"}`, + `{"type":"response.output_item.added","sequence_number":2,"output_index":0,"item":{"type":"function_call","id":"fc_late","call_id":"call_late","status":"in_progress","name":"read"}}`, + `{"type":"response.function_call_arguments.done","sequence_number":3,"item_id":"fc_late","output_index":0,"name":"read","arguments":"{\"filePath\":\"x\"}"}`, + } + for _, raw := range rawEvents { + if err := s.mapper.handleData(raw); err != nil { + t.Fatalf("handleData: %v", err) + } + } + + events := stream.DrainPending(&s.mapper.pending) + var startedCount int + var startedID, startedName string + var args strings.Builder + var finishedIDs []string + for i := range events { + ev := events[i] + switch ev.Kind { + case lipapi.EventToolCallStarted: + startedCount++ + startedID = ev.ToolCallID + startedName = ev.ToolName + case lipapi.EventToolCallArgsDelta: + args.WriteString(ev.Delta) + case lipapi.EventToolCallFinished: + finishedIDs = append(finishedIDs, ev.ToolCallID) + } + } + if startedCount != 1 { + t.Fatalf("tool call started count = %d, want 1 (no provisional/real duplicate); events=%+v", startedCount, events) + } + if startedID != "call_late" { + t.Fatalf("tool call started id = %q, want call_late; events=%+v", startedID, events) + } + if startedName != "read" { + t.Fatalf("tool started name = %q, want read; events=%+v", startedName, events) + } + if got := args.String(); got != `{"filePath":` { + t.Fatalf("args = %q, want incremental delta preserved after remap onto call_id", got) + } + if len(finishedIDs) != 1 || finishedIDs[0] != "call_late" { + t.Fatalf("tool call finished ids = %v, want [call_late]; events=%+v", finishedIDs, events) + } +} diff --git a/internal/plugins/backends/openaicodex/toolschema.go b/internal/plugins/backends/openaicodex/toolschema.go new file mode 100644 index 00000000..bf805511 --- /dev/null +++ b/internal/plugins/backends/openaicodex/toolschema.go @@ -0,0 +1,173 @@ +package openaicodex + +// isStrictCompatibleSchema reports whether a JSON schema satisfies the Codex +// Responses API "strict" tool-schema requirements: every object must declare +// additionalProperties:false and list all of its properties in required. Schemas +// that do not comply must be sent with strict:false, otherwise the upstream +// rejects the request (e.g. "additionalProperties is required to be supplied +// and to be false"). Parameterless object schemas are normalized by +// addStrictAdditionalProperties to include additionalProperties:false (and an +// empty required list) so they remain strict-compatible instead of leaking a +// strict:true tool that the upstream rejects. The check is conservative: when +// in doubt it returns false, which only relaxes strict mode (safe) and never +// causes an upstream rejection. +func isStrictCompatibleSchema(schema map[string]any) bool { + if hasRef(schema) || !strictCompatibleCompositions(schema) { + return false + } + if !isObjectSchema(schema) { + // Non-object root (array/primitive): only its array items must comply. + return strictCompatibleArrayItems(schema) + } + ap, ok := schema["additionalProperties"] + if !ok { + return false + } + if asBool, ok := ap.(bool); !ok || asBool { + return false + } + props, _ := schema["properties"].(map[string]any) + reqRaw, _ := schema["required"].([]any) + required := make(map[string]bool, len(reqRaw)) + for _, r := range reqRaw { + if s, ok := r.(string); ok { + required[s] = true + } + } + for name, prop := range props { + if !required[name] { + return false + } + child, ok := prop.(map[string]any) + if !ok { + return false + } + if !isStrictCompatibleChild(child) { + return false + } + } + return strictCompatibleArrayItems(schema) +} + +func normalizeToolSchemaForCodex(schema map[string]any) (map[string]any, bool) { + addStrictAdditionalProperties(schema) + return schema, isStrictCompatibleSchema(schema) +} + +// isObjectSchema reports whether a node is a JSON object schema: either it +// declares type "object" or it carries a non-empty properties map. Empty +// properties maps without an explicit object type are not treated as objects so +// that a truly empty schema ({}) is left untouched. +func isObjectSchema(node map[string]any) bool { + if t, _ := node["type"].(string); t == "object" { + return true + } + props, ok := node["properties"].(map[string]any) + return ok && len(props) > 0 +} + +func addStrictAdditionalProperties(v any) { + switch x := v.(type) { + case map[string]any: + // Only descend through subschema-bearing keywords so non-schema payloads + // (enum, default, examples, etc.) are not mutated with strict keywords. + if props, ok := x["properties"].(map[string]any); ok { + for k, child := range props { + addStrictAdditionalProperties(child) + props[k] = child + } + } + if items, ok := x["items"]; ok { + addStrictAdditionalProperties(items) + x["items"] = items + } + for _, key := range []string{"oneOf", "anyOf", "allOf"} { + if children, ok := x[key].([]any); ok { + for i, child := range children { + addStrictAdditionalProperties(child) + children[i] = child + } + } + } + if isObjectSchema(x) { + if _, ok := x["additionalProperties"]; !ok { + x["additionalProperties"] = false + } + // Parameterless objects must also carry an explicit required:[] for the + // Responses API strict mode; inject it only when no properties and no + // required are already declared. + if props, _ := x["properties"].(map[string]any); len(props) == 0 { + if _, ok := x["required"]; !ok { + x["required"] = []any{} + } + } + } + case []any: + for i, child := range x { + addStrictAdditionalProperties(child) + x[i] = child + } + } +} + +func isStrictCompatibleChild(node map[string]any) bool { + if hasRef(node) || !strictCompatibleCompositions(node) { + return false + } + if props, ok := node["properties"].(map[string]any); ok && len(props) > 0 { + return isStrictCompatibleSchema(node) + } + switch t, _ := node["type"].(string); t { + case "object": + return isStrictCompatibleSchema(node) + case "array": + return strictCompatibleArrayItems(node) + } + return true +} + +func hasRef(node map[string]any) bool { + _, ok := node["$ref"] + return ok +} + +func strictCompatibleCompositions(node map[string]any) bool { + for _, key := range []string{"oneOf", "anyOf", "allOf"} { + raw, ok := node[key] + if !ok { + continue + } + items, ok := raw.([]any) + if !ok || len(items) == 0 { + return false + } + for _, item := range items { + child, ok := item.(map[string]any) + if !ok || !isStrictCompatibleChild(child) { + return false + } + } + } + return true +} + +func strictCompatibleArrayItems(node map[string]any) bool { + raw, ok := node["items"] + if !ok { + return true + } + if child, ok := raw.(map[string]any); ok { + return isStrictCompatibleChild(child) + } + items, ok := raw.([]any) + if !ok { + return false + } + for _, item := range items { + child, ok := item.(map[string]any) + if !ok || !isStrictCompatibleChild(child) { + return false + } + } + return true +} diff --git a/internal/plugins/backends/openaicodex/transport.go b/internal/plugins/backends/openaicodex/transport.go new file mode 100644 index 00000000..6a005001 --- /dev/null +++ b/internal/plugins/backends/openaicodex/transport.go @@ -0,0 +1,120 @@ +package openaicodex + +import ( + "context" + "errors" + "fmt" + "io" + "sync" + "time" +) + +// wsTransportError marks a WebSocket failure that occurs before the first +// canonical event: dial/handshake, send of response.create, or first-frame +// read/close/timeout. Only these errors trigger auto HTTPS fallback. +// wsWrappedError is the shared base for WebSocket error sentinels: it formats a +// backend-prefixed message and unwraps to the underlying cause. Concrete types +// embed it so errors.As can still discriminate between transport and read errors. +type wsWrappedError struct { + prefix string + cause error +} + +func (e *wsWrappedError) Error() string { + return fmt.Sprintf("%s: %s: %v", ID, e.prefix, e.cause) +} + +func (e *wsWrappedError) Unwrap() error { + return e.cause +} + +type wsTransportError struct { + wsWrappedError +} + +func newWSTransportError(cause error) error { + if cause == nil { + return nil + } + return &wsTransportError{wsWrappedError{prefix: "websocket transport", cause: cause}} +} + +func isWSTransportFailure(err error) bool { + var e *wsTransportError + return errors.As(err, &e) +} + +type wsStreamReadError struct { + wsWrappedError +} + +func newWSStreamReadError(cause error) error { + if cause == nil { + return nil + } + return &wsStreamReadError{wsWrappedError{prefix: "read websocket", cause: cause}} +} + +func isWSStreamReadError(err error) bool { + var e *wsStreamReadError + return errors.As(err, &e) +} + +func wsPreFirstEventFailure(err error) error { + if err == nil || isWSTransportFailure(err) { + return err + } + if errors.Is(err, io.EOF) || isWSStreamReadError(err) { + return newWSTransportError(err) + } + return err +} + +// transportCooldown is a negative cache for WebSocket attempts. When auto mode +// records a fallback-eligible WS failure, markFailed pushes the cooldown window +// forward so subsequent auto attempts skip WS and go straight to HTTPS until the +// window expires. +type transportCooldown struct { + mu sync.Mutex + until time.Time + cooldown time.Duration + now func() time.Time +} + +func newTransportCooldown(cooldown time.Duration) *transportCooldown { + if cooldown <= 0 { + cooldown = DefaultWebSocketFallbackCooldown + } + return &transportCooldown{cooldown: cooldown, now: time.Now} +} + +func (c *transportCooldown) active() bool { + if c == nil { + return false + } + c.mu.Lock() + defer c.mu.Unlock() + return c.now().Before(c.until) +} + +func (c *transportCooldown) markFailed() { + if c == nil { + return + } + c.mu.Lock() + defer c.mu.Unlock() + c.until = c.now().Add(c.cooldown) +} + +// isWSFallbackError reports whether a WebSocket open failure should trigger auto +// fallback to HTTPS and record the cooldown. Context cancellation must never +// trigger fallback. Only typed pre-first-event transport failures qualify. +func isWSFallbackError(ctx context.Context, err error) bool { + if err == nil { + return false + } + if ctx != nil && ctx.Err() != nil { + return false + } + return isWSTransportFailure(err) +} diff --git a/internal/plugins/backends/openaicodex/transport_fallback_test.go b/internal/plugins/backends/openaicodex/transport_fallback_test.go new file mode 100644 index 00000000..52f08642 --- /dev/null +++ b/internal/plugins/backends/openaicodex/transport_fallback_test.go @@ -0,0 +1,228 @@ +package openaicodex_test + +import ( + "context" + "net/http/httptest" + "slices" + "testing" + "time" + + backend "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/backends/openaicodex" + refbackend "github.com/matdev83/go-llm-interactive-proxy/internal/refbackend/openaicodex" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +// TestOpen_autoFallsBackWhenWSStopsAfterLifecycleEvent verifies that a bare +// response.created lifecycle event does not commit the WebSocket attempt. Auto +// mode may still fall back to HTTPS when WS closes before user-visible output. +func TestOpen_autoFallsBackWhenWSStopsAfterLifecycleEvent(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{ + Token: "sk-codex", + OutputText: "http-ok", + ForcedWSFailure: refbackend.WSFailureAfterFirstEvent, + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + }) + es, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatalf("open: %v", err) + } + events := drainEvents(t, es) + + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport = %q, want https (fallback before committed output)", got) + } + if !slices.Contains(eventKindsList(events), lipapi.EventTextDelta) { + t.Fatalf("missing HTTPS fallback text delta: %+v", events) + } +} + +// TestOpen_websocketModeSuccess exercises the strict websocket transport on the +// happy path: the call completes over WS and never touches HTTPS. +func TestOpen_websocketModeSuccess(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ws-ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportWebSocket, + ExperimentalWebSocket: true, + }) + es, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + events := drainEvents(t, es) + if got := srv.LatestRequest().Transport; got != "websocket" { + t.Fatalf("transport = %q, want websocket", got) + } + if !slices.Contains(eventKindsList(events), lipapi.EventTextDelta) { + t.Fatalf("missing text delta: %+v", events) + } +} + +// TestOpen_cooldownExpiryRetriesWS verifies that after the fallback cooldown +// window elapses, auto mode retries WebSocket (and succeeds) instead of staying +// on HTTPS permanently. +func TestOpen_cooldownExpiryRetriesWS(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{ + Token: "sk-codex", + OutputText: "ws-ok-after", + ForcedWSFailure: refbackend.WSFailurePolicyCloseBeforeEvent, + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + WebSocketFallbackCooldown: 50 * time.Millisecond, + }) + + first, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, first) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("first transport = %q, want https (fallback after WS fail)", got) + } + + deadline := time.After(2 * time.Second) + ticker := time.NewTicker(5 * time.Millisecond) + defer ticker.Stop() + for { + second, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, second) + if got := srv.LatestRequest().Transport; got == "websocket" { + return + } + select { + case <-deadline: + t.Fatalf("second transport = %q, want websocket (cooldown expired, auto must retry WS)", srv.LatestRequest().Transport) + case <-ticker.C: + } + } +} + +func TestOpen_autoFallsBackOnWSNormalCloseBeforeFirstEvent(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{ + Token: "sk-codex", + OutputText: "http-ok", + ForcedWSFailure: refbackend.WSFailureNormalCloseBeforeEvent, + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + WebSocketFallbackCooldown: 50 * time.Millisecond, + }) + + first, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, first) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("first transport = %q, want https (normal close before first event)", got) + } + + second, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, second) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("second transport = %q, want https (cooldown after normal close fallback)", got) + } +} + +func TestOpen_autoFallsBackOnWSNoCanonicalFirstFrameThenClose(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{ + Token: "sk-codex", + OutputText: "http-ok", + ForcedWSFailure: refbackend.WSFailureNoCanonicalFirstFrame, + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + WebSocketFallbackCooldown: 50 * time.Millisecond, + }) + + first, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, first) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("first transport = %q, want https (no canonical first frame then close)", got) + } + + second, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, second) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("second transport = %q, want https (cooldown after no-canonical fallback)", got) + } +} + +func TestOpen_autoNoFallbackOnWSMalformedFirstFrame(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{ + Token: "sk-codex", + OutputText: "http-ok", + ForcedWSFailure: refbackend.WSFailureMalformedFirstFrame, + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + }) + _, err := be.Open(context.Background(), codexCall(), codexCand()) + if err == nil { + t.Fatal("expected open failure without HTTPS fallback") + } + if got := srv.LatestRequest().Transport; got == "https" { + t.Fatalf("transport = %q, must not fall back to HTTPS on mapper error", got) + } +} diff --git a/internal/plugins/backends/openaicodex/transport_internal_test.go b/internal/plugins/backends/openaicodex/transport_internal_test.go new file mode 100644 index 00000000..ae386da4 --- /dev/null +++ b/internal/plugins/backends/openaicodex/transport_internal_test.go @@ -0,0 +1,65 @@ +package openaicodex + +import ( + "errors" + "strings" + "testing" +) + +func TestWSTransportError_errorsContract(t *testing.T) { + t.Parallel() + cause := errors.New("dial boom") + err := newWSTransportError(cause) + if err == nil { + t.Fatal("newWSTransportError(non-nil) = nil, want error") + } + var target *wsTransportError + if !errors.As(err, &target) { + t.Fatalf("errors.As(*wsTransportError) = false, want true") + } + if !errors.Is(target, cause) { + t.Fatalf("errors.Is(target, cause) = false, want true (Unwrap must expose cause)") + } + if msg := err.Error(); !strings.Contains(msg, "websocket transport") || !strings.Contains(msg, "dial boom") { + t.Fatalf("Error() = %q, want it to contain %q and the cause", msg, "websocket transport") + } + if isWSStreamReadError(err) { + t.Fatalf("transport error must not match wsStreamReadError discriminator") + } +} + +func TestWSTransportError_nilCauseReturnsNil(t *testing.T) { + t.Parallel() + if err := newWSTransportError(nil); err != nil { + t.Fatalf("newWSTransportError(nil) = %v, want nil", err) + } +} + +func TestWSStreamReadError_errorsContract(t *testing.T) { + t.Parallel() + cause := errors.New("read boom") + err := newWSStreamReadError(cause) + if err == nil { + t.Fatal("newWSStreamReadError(non-nil) = nil, want error") + } + var target *wsStreamReadError + if !errors.As(err, &target) { + t.Fatalf("errors.As(*wsStreamReadError) = false, want true") + } + if !errors.Is(target, cause) { + t.Fatalf("errors.Is(target, cause) = false, want true (Unwrap must expose cause)") + } + if msg := err.Error(); !strings.Contains(msg, "read websocket") || !strings.Contains(msg, "read boom") { + t.Fatalf("Error() = %q, want it to contain %q and the cause", msg, "read websocket") + } + if isWSTransportFailure(err) { + t.Fatalf("stream read error must not match wsTransportError discriminator") + } +} + +func TestWSStreamReadError_nilCauseReturnsNil(t *testing.T) { + t.Parallel() + if err := newWSStreamReadError(nil); err != nil { + t.Fatalf("newWSStreamReadError(nil) = %v, want nil", err) + } +} diff --git a/internal/plugins/backends/openaicodex/transport_test.go b/internal/plugins/backends/openaicodex/transport_test.go new file mode 100644 index 00000000..9935c008 --- /dev/null +++ b/internal/plugins/backends/openaicodex/transport_test.go @@ -0,0 +1,206 @@ +package openaicodex_test + +import ( + "context" + "net/http/httptest" + "slices" + "testing" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" + backend "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/backends/openaicodex" + refbackend "github.com/matdev83/go-llm-interactive-proxy/internal/refbackend/openaicodex" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func codexCall() lipapi.Call { + return lipapi.Call{ + ID: "ws-call", + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hi")}, + }}, + } +} + +func codexCand() routing.AttemptCandidate { + return routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} +} + +func eventKindsList(events []lipapi.Event) []lipapi.EventKind { + out := make([]lipapi.EventKind, 0, len(events)) + for _, ev := range events { + out = append(out, ev.Kind) + } + return out +} + +func TestOpen_autoTransportUsesWebSocket(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ws-ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + }) + es, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + events := drainEvents(t, es) + if err := lipapi.ValidateEventSequence(events); err != nil { + t.Fatal(err) + } + if got := srv.LatestRequest().Transport; got != "websocket" { + t.Fatalf("transport = %q, want websocket (auto must try WS first)", got) + } + if !slices.Contains(eventKindsList(events), lipapi.EventTextDelta) { + t.Fatalf("missing text delta: %+v", events) + } +} + +func TestOpen_defaultTransportUsesHTTPS(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "https-ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + }) + es, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport = %q, want https (websocket is experimental opt-in)", got) + } +} + +func TestOpen_autoFallsBackToHTTPOnWSFail(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "http-ok", ForcedWSFailure: refbackend.WSFailurePolicyCloseBeforeEvent}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + }) + es, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport = %q, want https (auto must fall back after WS fail-before-first-event)", got) + } +} + +func TestOpen_websocketModeNoFallback(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok", ForcedWSFailure: refbackend.WSFailurePolicyCloseBeforeEvent}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportWebSocket, + ExperimentalWebSocket: true, + }) + _, err := be.Open(context.Background(), codexCall(), codexCand()) + if err == nil { + t.Fatal("expected websocket-only mode to surface WS failure without fallback") + } + if got := srv.LatestRequest().Transport; got == "https" { + t.Fatalf("websocket mode must not fall back to HTTPS; transport=%q", got) + } +} + +func TestOpen_httpsModeNeverUsesWebSocket(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportHTTPS, + }) + es, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, es) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport = %q, want https (https mode must never dial WS)", got) + } +} + +func TestOpen_cooldownSkipsWSAfterFailure(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ok", ForcedWSFailure: refbackend.WSFailurePolicyCloseBeforeEvent}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: backend.TransportAuto, + ExperimentalWebSocket: true, + }) + first, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, first) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("first call transport = %q, want https (fallback)", got) + } + + // Second call: WS would now succeed (one-shot fail consumed), but cooldown + // must keep auto mode on HTTPS. + second, err := be.Open(context.Background(), codexCall(), codexCand()) + if err != nil { + t.Fatal(err) + } + drainEvents(t, second) + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("second call transport = %q, want https (cooldown must skip WS)", got) + } +} + +func TestOpen_invalidTransportConfigError(t *testing.T) { + t.Parallel() + ts := httptest.NewServer(refbackend.New(refbackend.Config{Token: "sk-codex"}).Handler()) + t.Cleanup(ts.Close) + be := backend.New(backend.Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: "quic", + }) + _, err := be.Open(context.Background(), codexCall(), codexCand()) + if err == nil { + t.Fatal("expected config error for invalid transport") + } + _, err = be.ModelInventory.LoadModels(context.Background()) + if err == nil { + t.Fatal("expected inventory config error for invalid transport") + } +} diff --git a/internal/plugins/backends/openaicodex/usage_estimator_test.go b/internal/plugins/backends/openaicodex/usage_estimator_test.go index 318d1be9..c7a50c46 100644 --- a/internal/plugins/backends/openaicodex/usage_estimator_test.go +++ b/internal/plugins/backends/openaicodex/usage_estimator_test.go @@ -21,7 +21,7 @@ func TestEstimateUsage_textRequest_positiveTotalsAndMetadata(t *testing.T) { Parts: []lipapi.Part{lipapi.TextPart("hello codex")}, }}, } - ev, err := est.estimateUsage(context.Background(), call, "gpt-5.3-codex", "world") + ev, err := est.estimateUsage(context.Background(), call, "gpt-5.3-codex-spark", "world") if err != nil { t.Fatal(err) } @@ -46,7 +46,7 @@ func TestEstimateUsage_textRequest_positiveTotalsAndMetadata(t *testing.T) { if ev.Accounting.Tokenizer.Source != "github.com/tiktoken-go/tokenizer" { t.Fatalf("tokenizer source=%q", ev.Accounting.Tokenizer.Source) } - if ev.Accounting.Tokenizer.ModelUsed != "gpt-5.3-codex" { + if ev.Accounting.Tokenizer.ModelUsed != "gpt-5.3-codex-spark" { t.Fatalf("model used=%q", ev.Accounting.Tokenizer.ModelUsed) } } @@ -138,7 +138,7 @@ func TestEstimateUsage_imageRefURL_usesConservativeDefault(t *testing.T) { }, }}, } - ev, err := est.estimateUsage(context.Background(), call, "gpt-5.3-codex", "done") + ev, err := est.estimateUsage(context.Background(), call, "gpt-5.3-codex-spark", "done") if err != nil { t.Fatal(err) } diff --git a/internal/plugins/backends/openaicodex/ws.go b/internal/plugins/backends/openaicodex/ws.go new file mode 100644 index 00000000..3c54ad17 --- /dev/null +++ b/internal/plugins/backends/openaicodex/ws.go @@ -0,0 +1,808 @@ +package openaicodex + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/backends/streampeek" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +const wsHandshakeTimeout = 30 * time.Second + +const ( + wsSessionIdleTTL = 2 * time.Minute + wsSessionMaxEntries = 256 +) + +// wsFirstEventTimeout bounds the wait for the first canonical event after the +// WebSocket handshake. Without it, a server that upgrades but never sends would +// leave openWS blocked forever on conn.ReadMessage (which ignores ctx). It is a +// package var instead of a const so internal tests can shorten it; production +// callers always see the default. +var wsFirstEventTimeout = 30 * time.Second + +var errWSPreviousResponseNotFound = errors.New("websocket previous response not found") + +type wsSessionKey struct { + baseURL string + accountID string + accessToken string + conversation string +} + +type wsSessionStore struct { + mu sync.Mutex + sessions map[wsSessionKey]*wsSessionConn + idleTTL time.Duration + maxEntries int + now func() time.Time +} + +type wsSessionConn struct { + key wsSessionKey + store *wsSessionStore + sem chan struct{} + conn *websocket.Conn + lastUsed time.Time + idleTimer *time.Timer +} + +func newWSSessionStore() *wsSessionStore { + return &wsSessionStore{ + sessions: make(map[wsSessionKey]*wsSessionConn), + idleTTL: wsSessionIdleTTL, + maxEntries: wsSessionMaxEntries, + now: time.Now, + } +} + +func (s *wsSessionStore) acquire(ctx context.Context, client *http.Client, url string, cfg *Config, convID string) (*wsSessionConn, *http.Response, bool, error) { + key := wsSessionKey{ + baseURL: strings.TrimSpace(url), + accountID: strings.TrimSpace(cfg.AccountID), + accessToken: strings.TrimSpace(cfg.AccessToken), + conversation: strings.TrimSpace(convID), + } + s.mu.Lock() + session := s.sessions[key] + if session == nil { + session = &wsSessionConn{ + key: key, + store: s, + sem: make(chan struct{}, 1), + lastUsed: s.now(), + } + s.sessions[key] = session + s.pruneToCapLocked(session) + } + session.stopIdleTimerLocked() + s.mu.Unlock() + + if err := session.acquire(ctx); err != nil { + return nil, nil, false, err + } + if session.conn != nil { + return session, nil, true, nil + } + conn, resp, err := dialCodexWebSocket(ctx, client, url, cfg, convID) + if err != nil { + session.release(true) + return nil, resp, false, err + } + session.conn = conn + return session, resp, false, nil +} + +func (s *wsSessionStore) forgetLocked(key wsSessionKey, session *wsSessionConn) { + if s.sessions[key] == session { + delete(s.sessions, key) + } +} + +func (s *wsSessionStore) pruneToCapLocked(protected *wsSessionConn) { + for len(s.sessions) > s.maxEntries { + var oldestKey wsSessionKey + var oldest *wsSessionConn + for key, session := range s.sessions { + if session == protected { + continue + } + if !session.tryAcquire() { + continue + } + if oldest == nil || session.lastUsed.Before(oldest.lastUsed) { + if oldest != nil { + oldest.unlock() + } + oldestKey = key + oldest = session + continue + } + session.unlock() + } + if oldest == nil { + return + } + oldest.closeConnLocked() + s.forgetLocked(oldestKey, oldest) + oldest.unlock() + } +} + +func (s *wsSessionStore) closeIdle(key wsSessionKey, session *wsSessionConn) { + if !session.tryAcquire() { + return + } + defer session.unlock() + s.mu.Lock() + defer s.mu.Unlock() + session.closeConnLocked() + session.stopIdleTimerLocked() + s.forgetLocked(key, session) +} + +func (s *wsSessionConn) acquire(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + select { + case s.sem <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (s *wsSessionConn) tryAcquire() bool { + select { + case s.sem <- struct{}{}: + return true + default: + return false + } +} + +func (s *wsSessionConn) unlock() { + select { + case <-s.sem: + default: + } +} + +func (s *wsSessionConn) release(closeConn bool) { + if s.store == nil { + s.unlock() + return + } + s.store.mu.Lock() + if closeConn { + s.closeConnLocked() + s.stopIdleTimerLocked() + s.store.forgetLocked(s.key, s) + } else { + s.lastUsed = s.store.now() + s.scheduleIdleTimerLocked() + } + s.store.mu.Unlock() + s.unlock() +} + +func (s *wsSessionConn) closeConnLocked() { + if s.conn == nil { + return + } + _ = s.conn.Close() + s.conn = nil +} + +func (s *wsSessionConn) stopIdleTimerLocked() { + if s.idleTimer == nil { + return + } + s.idleTimer.Stop() + s.idleTimer = nil +} + +func (s *wsSessionConn) scheduleIdleTimerLocked() { + s.stopIdleTimerLocked() + if s.store == nil || s.store.idleTTL <= 0 { + return + } + key := s.key + store := s.store + s.idleTimer = time.AfterFunc(s.store.idleTTL, func() { + store.closeIdle(key, s) + }) +} + +// wsEndpoint converts an HTTPS Codex base URL into the WebSocket scheme used by +// the Codex Responses WebSocket transport. Path handling mirrors +// responsesEndpoint so the same base_url value configures both transports. +func wsEndpoint(baseURL string) string { + base := normalizedResponsesBase(baseURL) + switch { + case strings.HasPrefix(base, "https://"): + return "wss://" + strings.TrimPrefix(base, "https://") + case strings.HasPrefix(base, "http://"): + return "ws://" + strings.TrimPrefix(base, "http://") + default: + return base + } +} + +func newWSDialer(client *http.Client) *websocket.Dialer { + d := &websocket.Dialer{HandshakeTimeout: wsHandshakeTimeout} + if client != nil { + if t, ok := client.Transport.(*http.Transport); ok && t != nil { + d.Proxy = t.Proxy + d.NetDialContext = t.DialContext + if t.TLSClientConfig != nil { + d.TLSClientConfig = t.TLSClientConfig.Clone() + } else { + d.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12} + } + } + // When client.Transport is a custom RoundTripper (e.g. instrumentation) + // rather than *http.Transport, proxy/TLS settings cannot be introspected + // generically, so the WS dialer falls back to default networking. This + // differs from the HTTPS path that uses the same client. + } + return d +} + +// openWS dials the Codex Responses WebSocket, sends a response.create frame, and +// returns a managed event stream after the first canonical event is received. +// A failure before the first canonical event is returned as an error so the +// auto transport can fall back to HTTPS. +func openWS(ctx context.Context, cfg *Config, policy downgradePolicy, usageEst *usageEstimator, sessions *wsSessionStore, continuation *wsContinuationStore, call lipapi.Call, cand routing.AttemptCandidate) (lipapi.ManagedEventStream, error) { + env, err := prepareCodexOpenEnv(ctx, cfg, call, cand, policy) + if err != nil { + return nil, err + } + es, _, err := openWSPrepared(ctx, env, cfg, policy.modelForPlan(env.originalModel, cfg.PlanTypeHint), call, usageEst, sessions, continuation) + return es, err +} + +func openWSPrepared(ctx context.Context, env *codexOpenEnv, cfg *Config, model string, call lipapi.Call, usageEst *usageEstimator, sessions *wsSessionStore, continuation *wsContinuationStore) (lipapi.ManagedEventStream, *http.Response, error) { + es, resp, rawFirst, err := openWSPreparedAttempt(ctx, env, cfg, model, call, usageEst, sessions, continuation) + if err == nil { + return es, resp, nil + } + if !isWSFreePlanRejection(rawFirst, env.downgrade, env.originalModel) { + return nil, resp, err + } + es, resp, _, err = openWSPreparedAttempt(ctx, env, cfg, env.downgrade.target, call, usageEst, sessions, continuation) + return es, resp, err +} + +type wsOpenRetryDecision int + +const ( + wsOpenNoRetry wsOpenRetryDecision = iota + wsOpenRetryFreshSession + wsOpenRetryWithoutContinuation +) + +type wsOpenAttemptState struct { + allowContinuation bool + allowStaleRetry bool +} + +func openWSPreparedAttempt(ctx context.Context, env *codexOpenEnv, cfg *Config, model string, call lipapi.Call, usageEst *usageEstimator, sessions *wsSessionStore, continuation *wsContinuationStore) (lipapi.ManagedEventStream, *http.Response, []byte, error) { + state := wsOpenAttemptState{ + allowContinuation: true, + allowStaleRetry: true, + } + for { + es, resp, rawFirst, retry, err := openWSPreparedAttemptOnce(ctx, env, cfg, model, call, usageEst, sessions, continuation, state) + switch retry { + case wsOpenNoRetry: + return es, resp, rawFirst, err + case wsOpenRetryFreshSession: + state.allowStaleRetry = false + case wsOpenRetryWithoutContinuation: + state.allowContinuation = false + state.allowStaleRetry = false + } + } +} + +func openWSPreparedAttemptOnce(ctx context.Context, env *codexOpenEnv, cfg *Config, model string, call lipapi.Call, usageEst *usageEstimator, sessions *wsSessionStore, continuation *wsContinuationStore, state wsOpenAttemptState) (lipapi.ManagedEventStream, *http.Response, []byte, wsOpenRetryDecision, error) { + if sessions == nil { + sessions = newWSSessionStore() + } + if continuation == nil { + continuation = newWSContinuationStore(codexContinuationTTL, codexContinuationMaxEntries) + } + env.payload.Model = model + fullPayload := env.payload + fullInputFingerprints := append([]string(nil), env.inputFingerprints...) + continuationApplied := state.allowContinuation && continuation.prepareWithFingerprints(ctx, cfg, call, &env.payload, fullInputFingerprints) + clearPreparedContinuation := func() { + if continuationApplied { + continuation.invalidateWithFingerprints(cfg, call, &fullPayload, fullInputFingerprints) + } + } + restoreFullPayload := func() { + env.payload = fullPayload + } + frame, err := payloadToWSResponseCreate(env.payload) + if err != nil { + clearPreparedContinuation() + restoreFullPayload() + return nil, nil, nil, wsOpenNoRetry, err + } + session, resp, reusedSession, err := sessions.acquire(ctx, env.client, wsEndpoint(cfg.BaseURL), cfg, env.convID) + if err != nil { + clearPreparedContinuation() + // Restore the full payload snapshot before returning so a rotation retry on + // another account does not inherit this attempt's continuation-trimmed Input + // and PreviousResponseID. The other retry paths restore below for the same + // reason; the handshake-error path must too because it hands resp back to the + // managed loop, which rotates accounts on 401/403/429 reusing this env. + restoreFullPayload() + // Return the (body-closed) handshake response so the managed WS path can + // classify 401/403/429 handshakes and rotate to the next account. + return nil, resp, nil, wsOpenNoRetry, err + } + conn := session.conn + if err := writeWSResponseCreate(ctx, conn, frame); err != nil { + session.release(true) + if reusedSession && state.allowStaleRetry { + clearPreparedContinuation() + restoreFullPayload() + return nil, nil, nil, wsOpenRetryFreshSession, err + } + clearPreparedContinuation() + restoreFullPayload() + return nil, nil, nil, wsOpenNoRetry, err + } + effectiveModel := strings.TrimSpace(env.payload.Model) + if effectiveModel == "" { + effectiveModel = env.originalModel + } + // Read the first raw frame directly so a pre-content model rejection can be + // detected before canonical mapping: the mapper synthesizes a ResponseStarted + // event ahead of an EventError, which would hide the rejection from a + // first-canonical-event check. + rawFirst, rerr := readFirstNonEmptyWSMessage(ctx, conn, wsFirstEventTimeout) + if rerr != nil { + session.release(true) + if reusedSession && state.allowStaleRetry && isWSFallbackError(ctx, rerr) { + clearPreparedContinuation() + restoreFullPayload() + return nil, nil, nil, wsOpenRetryFreshSession, rerr + } + clearPreparedContinuation() + restoreFullPayload() + return nil, nil, nil, wsOpenNoRetry, rerr + } + if isWSFreePlanRejection(rawFirst, env.downgrade, env.originalModel) { + session.release(true) + clearPreparedContinuation() + restoreFullPayload() + return nil, resp, rawFirst, wsOpenNoRetry, fmt.Errorf("%s: websocket model rejected before first event", ID) + } + mapper := newCodexEventMapper(call.MaxPendingWireEvents) + if err := mapper.handleData(string(rawFirst)); err != nil { + session.release(true) + clearPreparedContinuation() + restoreFullPayload() + return nil, nil, rawFirst, wsOpenNoRetry, err + } + wsStream := newWSStreamWithMapper(conn, mapper) + wsStream.release = session.release + var managed lipapi.ManagedEventStream + if cfg.Transport == TransportWebSocket { + // Strict WS mode returns as soon as the first canonical event is available. + // This mirrors the HTTPS open contract and lets the frontend stream + // response.started immediately. Waiting here for committed output is only + // needed in auto mode, where the transport must still be able to downgrade + // to HTTPS before any downstream-visible content commits. + managed, rerr = openManagedFirstEvent(ctx, wsStream, usageEst, call, effectiveModel) + } else { + managed, rerr = openManagedUntilCommitted(ctx, wsStream, usageEst, call, effectiveModel, wsFirstEventTimeout) + } + if rerr != nil { + if continuationApplied && errors.Is(rerr, errWSPreviousResponseNotFound) { + continuation.invalidateWithFingerprints(cfg, call, &fullPayload, fullInputFingerprints) + restoreFullPayload() + wsStream.releaseOnce(true) + return nil, nil, rawFirst, wsOpenRetryWithoutContinuation, rerr + } + clearPreparedContinuation() + restoreFullPayload() + return nil, nil, rawFirst, wsOpenNoRetry, wsPreFirstEventFailure(rerr) + } + managed = newCodexContinuationRecordingStream(managed, cfg, call, fullPayload, fullInputFingerprints, mapper, continuation) + // The opening boundary has been reached: strict websocket mode returns after + // the first canonical event, while auto mode waits until output is committed + // or terminal. Clear the deadline so subsequent streaming reads are governed + // by caller contexts rather than the open-time fallback window. + _ = conn.SetReadDeadline(time.Time{}) + return managed, resp, rawFirst, wsOpenNoRetry, nil +} + +func openManagedUntilCommitted(ctx context.Context, es lipapi.ManagedEventStream, usageEst *usageEstimator, call lipapi.Call, model string, timeout time.Duration) (lipapi.ManagedEventStream, error) { + managed := newUsageEstimatingStream(es, usageEst, call, model) + recvCtx := ctx + cancel := func() {} + if timeout > 0 { + recvCtx, cancel = context.WithTimeout(ctx, timeout) + } + defer cancel() + + var first []lipapi.Event + for { + ev, err := managed.Recv(recvCtx) + if err != nil { + _ = managed.Close() + if ctx != nil && ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, newWSTransportError(err) + } + first = append(first, ev) + if ev.Kind == lipapi.EventError { + _ = managed.Close() + if ev.ErrorCode == "previous_response_not_found" { + return nil, errWSPreviousResponseNotFound + } + return nil, fmt.Errorf("%s: upstream websocket error before output: %s", ID, ev.ErrorMessage) + } + if wsOpenCommitted(ev) { + return prependManagedEvents(first, managed), nil + } + } +} + +func wsOpenCommitted(ev lipapi.Event) bool { + return lipapi.OutputCommitted(ev) || ev.Kind == lipapi.EventError || ev.Kind == lipapi.EventResponseFinished +} + +func prependManagedEvents(events []lipapi.Event, rest lipapi.ManagedEventStream) lipapi.ManagedEventStream { + out := rest + for i := len(events) - 1; i >= 0; i-- { + out = streampeek.NewManagedPrependFirst(events[i], out) + } + return out +} + +var _ lipapi.ManagedEventStream = (*codexContinuationRecordingStream)(nil) + +type codexContinuationRecordingStream struct { + inner lipapi.ManagedEventStream + cfg *Config + call lipapi.Call + payload Payload + inputFP []string + mapper *codexEventMapper + store *wsContinuationStore + once sync.Once + mu sync.Mutex + recorded bool +} + +func newCodexContinuationRecordingStream(inner lipapi.ManagedEventStream, cfg *Config, call lipapi.Call, payload Payload, inputFingerprints []string, mapper *codexEventMapper, store *wsContinuationStore) lipapi.ManagedEventStream { + return &codexContinuationRecordingStream{ + inner: inner, + cfg: cfg, + call: call, + payload: payload, + inputFP: append([]string(nil), inputFingerprints...), + mapper: mapper, + store: store, + } +} + +func (s *codexContinuationRecordingStream) Recv(ctx context.Context) (lipapi.Event, error) { + ev, err := s.inner.Recv(ctx) + if err == nil && ev.Kind == lipapi.EventResponseFinished { + s.record() + } + return ev, err +} + +func (s *codexContinuationRecordingStream) Close() error { + err := s.inner.Close() + if !s.wasRecorded() { + s.store.invalidateWithFingerprints(s.cfg, s.call, &s.payload, s.inputFP) + } + return err +} + +func (s *codexContinuationRecordingStream) Cancel(ctx context.Context, cause lipapi.CancelCause) lipapi.CancelResult { + res := s.inner.Cancel(ctx, cause) + if !s.wasRecorded() { + s.store.invalidateWithFingerprints(s.cfg, s.call, &s.payload, s.inputFP) + } + return res +} + +func (s *codexContinuationRecordingStream) record() { + s.once.Do(func() { + if s.mapper == nil { + return + } + if strings.TrimSpace(s.mapper.responseID) == "" { + return + } + s.mu.Lock() + s.recorded = true + s.mu.Unlock() + s.store.recordWithFingerprints(s.cfg, s.call, s.payload, s.inputFP, s.mapper.responseID, s.mapper.outputItems...) + }) +} + +func (s *codexContinuationRecordingStream) wasRecorded() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.recorded +} + +// readFirstNonEmptyWSMessage reads WebSocket text frames, skipping empty ones, until +// the first non-empty frame arrives. Pre-first-event read/close failures are wrapped +// as wsTransportError so auto mode can fall back to HTTPS. The caller sets the read +// deadline. +func readFirstNonEmptyWSMessage(ctx context.Context, conn *websocket.Conn, timeout time.Duration) ([]byte, error) { + if timeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + } + stopCancel := func() bool { return true } + if ctx != nil { + stopCancel = context.AfterFunc(ctx, func() { + _ = conn.SetReadDeadline(time.Now()) + }) + } + defer stopCancel() + for { + _, data, err := conn.ReadMessage() + if err != nil { + if ctx != nil && ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, newWSTransportError(fmt.Errorf("read websocket: %w", err)) + } + if len(strings.TrimSpace(string(data))) > 0 { + return data, nil + } + } +} + +func writeWSResponseCreate(ctx context.Context, conn *websocket.Conn, frame json.RawMessage) error { + if wsFirstEventTimeout > 0 { + _ = conn.SetWriteDeadline(time.Now().Add(wsFirstEventTimeout)) + } + stopCancel := func() bool { return true } + if ctx != nil { + stopCancel = context.AfterFunc(ctx, func() { + _ = conn.SetWriteDeadline(time.Now()) + }) + } + err := conn.WriteJSON(frame) + stopCancel() + _ = conn.SetWriteDeadline(time.Time{}) + if err == nil { + return nil + } + if ctx != nil && ctx.Err() != nil { + return ctx.Err() + } + return newWSTransportError(fmt.Errorf("websocket send response.create: %w", err)) +} + +// isWSFreePlanRejection reports whether a raw WebSocket frame is a pre-content error +// event whose message matches a free-plan gpt-5.5 rejection. Mirrors the HTTP path's +// downgradePolicy.isFreePlanRejection but operates on an error event frame instead of +// an HTTP status+body pair, since the WebSocket transport has no status code. +func isWSFreePlanRejection(rawFrame []byte, policy downgradePolicy, originalModel string) bool { + var probe struct { + Type string `json:"type"` + Error *struct { + Message string `json:"message"` + } `json:"error"` + } + if err := json.Unmarshal(rawFrame, &probe); err != nil { + return false + } + if probe.Type != "error" || probe.Error == nil { + return false + } + return policy.shouldReactiveRetry(originalModel, false, probe.Error.Message) +} + +func dialCodexWebSocket(ctx context.Context, client *http.Client, url string, cfg *Config, convID string) (*websocket.Conn, *http.Response, error) { + d := newWSDialer(client) + conn, resp, err := d.DialContext(ctx, url, codexWSHeaders(*cfg, convID)) + if err != nil { + if resp != nil { + // Body is closed but resp.StatusCode/Header remain readable so callers + // (e.g. managed WS rotation) can classify 401/403/429 handshakes. + _ = resp.Body.Close() + return nil, resp, newWSTransportError(fmt.Errorf("websocket dial: %w (status=%s)", err, resp.Status)) + } + return nil, nil, newWSTransportError(fmt.Errorf("websocket dial: %w", err)) + } + return conn, resp, nil +} + +const wsFrameTypeResponseCreate = "response.create" + +type wsResponseCreateFrame struct { + Type string `json:"type"` + Payload +} + +// payloadToWSResponseCreate builds a WebSocket response.create frame from a Codex +// HTTPS payload: same fields with stream omitted and type set explicitly. +func payloadToWSResponseCreate(p Payload) (json.RawMessage, error) { + p.Stream = false + frame := wsResponseCreateFrame{ + Type: wsFrameTypeResponseCreate, + Payload: p, + } + out, err := json.Marshal(frame) + if err != nil { + return nil, fmt.Errorf("%s: marshal ws frame: %w", ID, err) + } + return out, nil +} + +var _ lipapi.ManagedEventStream = (*wsStream)(nil) + +type wsStream struct { + mapper *codexEventMapper + mu sync.Mutex + conn *websocket.Conn + closed bool + release func(closeConn bool) + releaseOnceF sync.Once +} + +func newWSStream(conn *websocket.Conn, maxPending int) *wsStream { + return newWSStreamWithMapper(conn, newCodexEventMapper(maxPending)) +} + +// newWSStreamWithMapper builds a wsStream over a pre-existing event mapper. The caller +// may have already populated the mapper's pending queue (e.g. from a pre-read first +// frame); the stream's Recv drains pending before reading the next wire frame. +func newWSStreamWithMapper(conn *websocket.Conn, mapper *codexEventMapper) *wsStream { + return &wsStream{ + mapper: mapper, + conn: conn, + } +} + +func (s *wsStream) Recv(ctx context.Context) (lipapi.Event, error) { + if ctx == nil { + return lipapi.Event{}, lipapi.ErrNilContext + } + if err := ctx.Err(); err != nil { + return lipapi.Event{}, err + } + for { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return lipapi.Event{}, io.EOF + } + if ev, ok := s.mapper.pending.PopFront(); ok { + s.mu.Unlock() + return ev, nil + } + if s.mapper.terminal { + s.mu.Unlock() + s.releaseOnce(false) + return lipapi.Event{}, io.EOF + } + s.mu.Unlock() + + text, ok, err := s.readMessage(ctx) + if err != nil { + s.mu.Lock() + closed := s.closed + s.mu.Unlock() + if closed { + return lipapi.Event{}, io.EOF + } + s.releaseOnce(true) + return lipapi.Event{}, err + } + if !ok { + s.mu.Lock() + closed := s.closed + s.mu.Unlock() + if closed { + return lipapi.Event{}, io.EOF + } + s.releaseOnce(false) + return lipapi.Event{}, io.EOF + } + if text == "" { + continue + } + + s.mu.Lock() + if s.closed { + s.mu.Unlock() + continue + } + if err := s.mapper.handleData(text); err != nil { + s.mu.Unlock() + return lipapi.Event{}, err + } + s.mu.Unlock() + } +} + +func (s *wsStream) readMessage(ctx context.Context) (string, bool, error) { + stopCancel := context.AfterFunc(ctx, func() { + _ = s.conn.SetReadDeadline(time.Now()) + }) + defer stopCancel() + _, data, err := s.conn.ReadMessage() + if err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + _ = s.conn.SetReadDeadline(time.Time{}) + return "", false, ctxErr + } + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + return "", false, io.EOF + } + return "", false, newWSStreamReadError(err) + } + text := strings.TrimSpace(string(data)) + if text == "" { + return "", true, nil + } + return text, true, nil +} + +func (s *wsStream) Close() error { + closeConn := true + s.mu.Lock() + if s.mapper.terminal { + closeConn = false + } + s.mu.Unlock() + if closeConn && s.conn != nil { + // Close first, without taking s.mu: Recv holds that lock while blocked in + // ReadMessage, so taking it before closing would deadlock cancellation. + _ = s.conn.Close() + } + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return nil + } + s.closed = true + s.releaseOnce(closeConn) + return nil +} + +func (s *wsStream) Cancel(context.Context, lipapi.CancelCause) lipapi.CancelResult { + // Codex WebSocket does not have a request-cancel frame in this adapter. Close + // the socket instead of pretending cancellation is protocol-level; this also + // prevents reuse of an in-flight session whose upstream generation may still be + // producing frames. + return lipapi.CancelResult{Mode: lipapi.CancelModeCloseOnly, Err: s.Close()} +} + +func (s *wsStream) releaseOnce(closeConn bool) { + s.releaseOnceF.Do(func() { + if s.release != nil { + s.release(closeConn) + } + }) +} diff --git a/internal/plugins/backends/openaicodex/ws_continuation_rotation_internal_test.go b/internal/plugins/backends/openaicodex/ws_continuation_rotation_internal_test.go new file mode 100644 index 00000000..929ff66f --- /dev/null +++ b/internal/plugins/backends/openaicodex/ws_continuation_rotation_internal_test.go @@ -0,0 +1,144 @@ +package openaicodex + +import ( + "context" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" + refbackend "github.com/matdev83/go-llm-interactive-proxy/internal/refbackend/openaicodex" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +// TestWSContinuationRotation_preservesFullPayloadAcrossAccounts is a regression +// repro for the managed-OAuth WebSocket account-rotation path reported by Bugbot. +// +// When the first account's WS handshake returns 401/403/429, openWSPreparedAttemptOnce +// clears the prepared continuation entry but does NOT restore env.payload from the +// fullPayload snapshot (unlike the write/read/previous-response retry paths, which +// all do env.payload = fullPayload). Because openManagedAccountLoop reuses a single +// *codexOpenEnv across account retries, the next account is dialed with a +// continuation-trimmed Input and a PreviousResponseID that belongs to the first +// account. +// +// This test asserts the correct behavior (full input, no foreign previous_response_id +// on the rotated account) and is expected to FAIL on the current code, proving the +// finding is not a false positive. +func TestWSContinuationRotation_preservesFullPayloadAcrossAccounts(t *testing.T) { + t.Parallel() + + // Two managed accounts. The refbackend is configured with account B's token, so + // account A's WS handshake returns 401 and openManagedAccountLoop rotates to B. + dir := t.TempDir() + accountFiles := []struct{ name, id, token string }{ + {"a.json", "acct-a", "tok-a"}, + {"b.json", "acct-b", "tok-b"}, + } + for _, af := range accountFiles { + path := filepath.Join(dir, af.name) + if err := os.WriteFile(path, []byte(`{"account_id":"`+af.id+`","access_token":"`+af.token+`"}`), 0o600); err != nil { + t.Fatal(err) + } + } + store, err := newAccountStore(Config{ + ManagedOAuthStoragePath: dir, + ManagedOAuthSelectionStrategy: "first-available", + RateLimitFallback: time.Hour, + }) + if err != nil { + t.Fatal(err) + } + + srv := refbackend.New(refbackend.Config{Token: "tok-b", OutputText: "ws-ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + cfg := Config{ + BaseURL: ts.URL + "/backend-api/codex", + HTTPClient: ts.Client(), + Transport: TransportWebSocket, + } + policy := newDowngradePolicy(cfg) + + call := lipapi.Call{ + ID: "repro-rotation-call", + Session: lipapi.SessionRef{ContinuityKey: "conv-repro"}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("inspect")}}, + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("continue")}}, + }, + } + cand := routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} + + // Reconstruct the exact payload openManagedAccountLoop will build, so the seeded + // continuation entry for account A has matching Instructions/Tools/PromptCacheKey + // fingerprints and is actually applied on the first attempt. + seedEnv, err := prepareCodexOpenEnv(context.Background(), &cfg, call, cand, policy) + if err != nil { + t.Fatal(err) + } + if got := len(seedEnv.payload.Input); got != 2 { + t.Fatalf("seed payload input len = %d, want 2 (inspect + continue)", got) + } + basePayload := seedEnv.payload + basePayload.Input = append([]inputItem(nil), seedEnv.payload.Input[:1]...) // [user "inspect"] + + cfgA := cfg + cfgA.AccountID = "acct-a" + + // Sanity check on a throwaway store: the seeded entry must be applied for account A, + // trimming Input and setting PreviousResponseID. This proves the rotation path + // actually enters the continuation branch before the 401 handshake. + checkStore := newWSContinuationStore(time.Minute, 8) + checkStore.record(&cfgA, call, basePayload, "resp_a") + checkPayload := seedEnv.payload + if !checkStore.prepareWithFingerprints(context.Background(), &cfgA, call, &checkPayload, seedEnv.inputFingerprints) { + t.Fatal("seeded continuation entry was not applied for account A; repro setup is broken") + } + if checkPayload.PreviousResponseID != "resp_a" || len(checkPayload.Input) != 1 { + t.Fatalf("sanity check: continuation did not trim payload: prev=%q input=%#v", + checkPayload.PreviousResponseID, checkPayload.Input) + } + + // Real run: seed the continuation store used by openManagedWS. + continuation := newWSContinuationStore(time.Minute, 8) + continuation.record(&cfgA, call, basePayload, "resp_a") + + es, err := openManagedWS(context.Background(), &cfg, store, call, cand, policy, nil, newWSSessionStore(), continuation) + if err != nil { + t.Fatalf("openManagedWS: %v", err) + } + t.Cleanup(func() { _ = es.Close() }) + // The refbackend captures the response.create frame before sending any event, so + // the first Recv guarantees the wire payload is recorded. + if _, err := es.Recv(context.Background()); err != nil { + t.Fatalf("first Recv: %v", err) + } + + captured := srv.LatestRequest() + if captured.Transport != "websocket" { + t.Fatalf("captured transport = %q, want websocket", captured.Transport) + } + if got := captured.Authorization; got != "Bearer tok-b" { + t.Fatalf("captured authorization = %q, want %q (rotation must reach account B)", got, "Bearer tok-b") + } + + input, _ := captured.Body["input"].([]any) + if len(input) != 2 { + t.Fatalf("rotated account B received input len = %d, want 2 (full payload must be sent to the fresh account); input=%#v", + len(input), input) + } + first, _ := input[0].(map[string]any) + if first["content"] != "inspect" { + t.Fatalf("rotated account B first input content = %v, want %q (continuation-trimmed payload leaked from account A)", + first["content"], "inspect") + } + + if prev, ok := captured.Body["previous_response_id"]; ok { + t.Fatalf("rotated account B received previous_response_id = %v; continuation state from account A must not leak across accounts", + prev) + } +} diff --git a/internal/plugins/backends/openaicodex/ws_first_event_internal_test.go b/internal/plugins/backends/openaicodex/ws_first_event_internal_test.go new file mode 100644 index 00000000..5160d739 --- /dev/null +++ b/internal/plugins/backends/openaicodex/ws_first_event_internal_test.go @@ -0,0 +1,225 @@ +package openaicodex + +import ( + "context" + "errors" + "net" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" + refbackend "github.com/matdev83/go-llm-interactive-proxy/internal/refbackend/openaicodex" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func internalCodexCall() lipapi.Call { + return lipapi.Call{ + ID: "ws-stall-call", + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hi")}, + }}, + } +} + +func internalCodexCand() routing.AttemptCandidate { + return routing.AttemptCandidate{Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}} +} + +// drainUntilEnd reads events until the stream terminates, accepting either EOF +// or a stream error as a valid end. Used where a mid-stream failure is expected. +func drainUntilEnd(t *testing.T, es lipapi.ManagedEventStream) { + t.Helper() + for { + ev, err := es.Recv(context.Background()) + if err != nil { + _ = es.Close() + return + } + _ = ev + } +} + +func TestNewWSDialerCopiesHTTPTransportDialer(t *testing.T) { + t.Parallel() + var dialed bool + client := &http.Client{Transport: &http.Transport{ + DialContext: func(context.Context, string, string) (net.Conn, error) { + dialed = true + return nil, errors.New("custom dialer used") + }, + }} + + d := newWSDialer(client) + if d.NetDialContext == nil { + t.Fatal("websocket dialer did not copy http.Transport.DialContext") + } + _, err := d.NetDialContext(context.Background(), "tcp", "example.invalid:443") + if err == nil { + t.Fatal("expected custom dialer error") + } + if !dialed { + t.Fatal("custom transport dialer was not used") + } +} + +// TestOpen_autoFallsBackOnWSStallWithinTimeout verifies that when the WebSocket +// upgrade succeeds but the server never sends a first event, auto transport does +// not hang: the first-event read deadline fires and the call falls back to HTTPS +// within the configured window. Not parallel: temporarily shortens the package +// var wsFirstEventTimeout. +// +//nolint:paralleltest +func TestOpen_autoFallsBackOnWSStallWithinTimeout(t *testing.T) { + prev := wsFirstEventTimeout + wsFirstEventTimeout = 200 * time.Millisecond + t.Cleanup(func() { wsFirstEventTimeout = prev }) + + srv := refbackend.New(refbackend.Config{ + Token: "sk-codex", + OutputText: "http-ok", + ForcedWSFailure: refbackend.WSFailureStall, + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := New(Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: TransportAuto, + ExperimentalWebSocket: true, + }) + + start := time.Now() + es, err := be.Open(context.Background(), internalCodexCall(), internalCodexCand()) + if err != nil { + t.Fatalf("open: %v", err) + } + drainUntilEnd(t, es) + + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport = %q, want https (auto must fall back when WS stalls before first event)", got) + } + if elapsed := time.Since(start); elapsed > 2*time.Second { + t.Fatalf("fallback took %v, expected to be bounded by the short first-event timeout", elapsed) + } +} + +// TestOpen_autoFallsBackOnWSStallAfterLifecycleEvent verifies that response.created +// alone does not commit the WebSocket attempt. If WS stalls before text/tool output, +// auto transport must fall back to HTTPS instead of leaving the client with no content. +// +//nolint:paralleltest +func TestOpen_autoFallsBackOnWSStallAfterLifecycleEvent(t *testing.T) { + prev := wsFirstEventTimeout + wsFirstEventTimeout = 200 * time.Millisecond + t.Cleanup(func() { wsFirstEventTimeout = prev }) + + srv := refbackend.New(refbackend.Config{ + Token: "sk-codex", + OutputText: "http-ok", + ForcedWSFailure: refbackend.WSFailureStallAfterFirstEvent, + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := New(Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: TransportAuto, + ExperimentalWebSocket: true, + }) + + start := time.Now() + es, err := be.Open(context.Background(), internalCodexCall(), internalCodexCand()) + if err != nil { + t.Fatalf("open: %v", err) + } + drainUntilEnd(t, es) + + if got := srv.LatestRequest().Transport; got != "https" { + t.Fatalf("transport = %q, want https (auto must fall back before committed output)", got) + } + if elapsed := time.Since(start); elapsed > 2*time.Second { + t.Fatalf("fallback took %v, expected to be bounded by the short first-event timeout", elapsed) + } +} + +// TestOpen_wsFirstEventTimeoutZeroDoesNotPanic ensures clearing the deadline +// path is exercised by a normal WS success after the timeout var was mutated. +// +//nolint:paralleltest +func TestOpen_wsFirstEventTimeoutZeroDoesNotPanic(t *testing.T) { + prev := wsFirstEventTimeout + wsFirstEventTimeout = 5 * time.Second + t.Cleanup(func() { wsFirstEventTimeout = prev }) + + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ws-ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := New(Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: TransportWebSocket, + ExperimentalWebSocket: true, + }) + es, err := be.Open(context.Background(), internalCodexCall(), internalCodexCand()) + if err != nil { + t.Fatalf("open: %v", err) + } + drainUntilEnd(t, es) + if got := srv.LatestRequest().Transport; got != "websocket" { + t.Fatalf("transport = %q, want websocket", got) + } +} + +func TestWSStreamRecvContextCancelsAfterFirstEvent(t *testing.T) { + t.Parallel() + + srv := refbackend.New(refbackend.Config{ + Token: "sk-codex", + OutputText: "ws-stall", + ForcedWSFailure: refbackend.WSFailureStallAfterFirstEvent, + }) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + be := New(Config{ + BaseURL: ts.URL + "/backend-api/codex", + AccessToken: "sk-codex", + HTTPClient: ts.Client(), + Transport: TransportWebSocket, + ExperimentalWebSocket: true, + }) + es, err := be.Open(context.Background(), internalCodexCall(), internalCodexCand()) + if err != nil { + t.Fatalf("open: %v", err) + } + t.Cleanup(func() { _ = es.Close() }) + if _, err := es.Recv(context.Background()); err != nil { + t.Fatalf("first Recv: %v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + done := make(chan error, 1) + go func() { + _, err := es.Recv(ctx) + done <- err + }() + + select { + case err := <-done: + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Recv error = %v, want context deadline", err) + } + case <-time.After(time.Second): + t.Fatal("Recv did not return after context deadline") + } +} diff --git a/internal/plugins/backends/openaicodex/ws_internal_test.go b/internal/plugins/backends/openaicodex/ws_internal_test.go new file mode 100644 index 00000000..5ad59f56 --- /dev/null +++ b/internal/plugins/backends/openaicodex/ws_internal_test.go @@ -0,0 +1,366 @@ +package openaicodex + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + gorillawebsocket "github.com/gorilla/websocket" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func TestWSEndpoint_schemeAndPath(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in, want string + }{ + {"https base path", "https://chatgpt.com/backend-api/codex", "wss://chatgpt.com/backend-api/codex/responses"}, + {"http localhost", "http://127.0.0.1:9/codex", "ws://127.0.0.1:9/codex/responses"}, + {"already responses path", "https://h/backend-api/codex/responses", "wss://h/backend-api/codex/responses"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := wsEndpoint(tc.in); got != tc.want { + t.Fatalf("wsEndpoint(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +func TestPayloadToWSResponseCreate_addsTypeRemovesStream(t *testing.T) { + t.Parallel() + p := Payload{ + Model: "gpt-5.3-codex-spark", + Stream: true, + Store: false, + Input: []inputItem{textMessageItem{Type: "message", Role: "user", Content: "hi"}}, + } + raw, err := payloadToWSResponseCreate(p) + if err != nil { + t.Fatal(err) + } + var m map[string]json.RawMessage + if err := json.Unmarshal(raw, &m); err != nil { + t.Fatal(err) + } + if typ := strings.TrimSpace(string(m["type"])); typ != `"response.create"` { + t.Fatalf("type = %s, want response.create", typ) + } + if _, ok := m["stream"]; ok { + t.Fatalf("stream must be omitted from WS frame: %s", m["stream"]) + } + if _, ok := m["model"]; !ok { + t.Fatalf("model must be preserved: %#v", m) + } +} + +func TestResponsesEndpoint_normalizesPath(t *testing.T) { + t.Parallel() + cases := []struct { + name string + in, want string + }{ + {"base path", "https://chatgpt.com/backend-api/codex", "https://chatgpt.com/backend-api/codex/responses"}, + {"already responses path", "https://h/backend-api/codex/responses", "https://h/backend-api/codex/responses"}, + {"trim trailing slash", " http://127.0.0.1:9/codex/ ", "http://127.0.0.1:9/codex/responses"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := responsesEndpoint(tc.in); got != tc.want { + t.Fatalf("responsesEndpoint(%q) = %q, want %q", tc.in, got, tc.want) + } + }) + } +} + +func TestIsWSFallbackError(t *testing.T) { + t.Parallel() + if isWSFallbackError(context.Background(), nil) { + t.Fatal("nil error must not be fallback-eligible") + } + if isWSFallbackError(context.Background(), errors.New("dial failed")) { + t.Fatal("unclassified error must not be fallback-eligible") + } + if !isWSFallbackError(context.Background(), newWSTransportError(errors.New("dial failed"))) { + t.Fatal("classified transport error must be fallback-eligible") + } + if isWSFallbackError(context.Background(), fmt.Errorf("%s: marshal payload: %w", ID, errors.New("bad"))) { + t.Fatal("payload marshal error must not be fallback-eligible") + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if isWSFallbackError(ctx, newWSTransportError(errors.New("dial failed"))) { + t.Fatal("cancelled context must not be treated as fallback-eligible") + } +} + +func TestWsPreFirstEventFailure(t *testing.T) { + t.Parallel() + if got := wsPreFirstEventFailure(nil); got != nil { + t.Fatalf("nil = %v, want nil", got) + } + transport := newWSTransportError(errors.New("dial")) + if got := wsPreFirstEventFailure(transport); got != transport { + t.Fatal("existing transport error must pass through") + } + if !isWSFallbackError(context.Background(), wsPreFirstEventFailure(io.EOF)) { + t.Fatal("pre-first-event EOF must be fallback-eligible") + } + readErr := newWSStreamReadError(errors.New("i/o timeout")) + if !isWSFallbackError(context.Background(), wsPreFirstEventFailure(readErr)) { + t.Fatal("pre-first-event ws read error must be fallback-eligible") + } + mapperErr := fmt.Errorf("%s: malformed stream event: %w", ID, errors.New("bad json")) + if isWSFallbackError(context.Background(), wsPreFirstEventFailure(mapperErr)) { + t.Fatal("mapper error must not be fallback-eligible") + } +} + +func TestTransportCooldown_markAndExpiry(t *testing.T) { + t.Parallel() + now := time.Time{} + c := &transportCooldown{cooldown: 5 * time.Minute, now: func() time.Time { return now }} + if c.active() { + t.Fatal("cooldown must be inactive initially") + } + c.markFailed() + if !c.active() { + t.Fatal("cooldown must be active after failure") + } + if !c.until.Equal(now.Add(5 * time.Minute)) { + t.Fatalf("until = %v, want %v", c.until, now.Add(5*time.Minute)) + } + // After the window expires, auto mode may try WS again. + c.now = func() time.Time { return now.Add(6 * time.Minute) } + if c.active() { + t.Fatal("cooldown must expire after the window") + } +} + +func TestTransportCooldown_zeroCooldownUsesDefault(t *testing.T) { + t.Parallel() + c := newTransportCooldown(0) + if c.cooldown != DefaultWebSocketFallbackCooldown { + t.Fatalf("cooldown = %v, want default %v", c.cooldown, DefaultWebSocketFallbackCooldown) + } +} + +func TestIsWSFreePlanRejection(t *testing.T) { + t.Parallel() + p := newDowngradePolicy(Config{}) + frame := []byte(`{"type":"error","error":{"message":"gpt-5.5 is not available on free plan"}}`) + if !isWSFreePlanRejection(frame, p, "gpt-5.5") { + t.Fatal("expected WS free-plan rejection") + } + if isWSFreePlanRejection(frame, p, "gpt-5.4") { + t.Fatal("non-source model must not match") + } + if isWSFreePlanRejection([]byte(`{"type":"response.created"}`), p, "gpt-5.5") { + t.Fatal("non-error frame must not match") + } +} + +func TestWSSessionStoreAcquireHonorsContextWhileCheckedOut(t *testing.T) { + t.Parallel() + store := newWSSessionStore() + session := &wsSessionConn{ + key: wsSessionKey{baseURL: "ws://example.test", accessToken: "tok", conversation: "sess"}, + store: store, + sem: make(chan struct{}, 1), + } + if err := session.acquire(context.Background()); err != nil { + t.Fatalf("first acquire: %v", err) + } + store.sessions[session.key] = session + + ctx, cancel := context.WithTimeout(context.Background(), 25*time.Millisecond) + defer cancel() + got, resp, reused, err := store.acquire(ctx, http.DefaultClient, session.key.baseURL, &Config{AccessToken: "tok"}, "sess") + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("acquire error = %v, want deadline exceeded", err) + } + if got != nil || resp != nil || reused { + t.Fatalf("acquire returned session=%v resp=%v reused=%v on timeout", got, resp, reused) + } + session.release(true) +} + +func TestWSSessionStoreIdleTimerForgetsReusableSession(t *testing.T) { + t.Parallel() + store := newWSSessionStore() + store.idleTTL = 10 * time.Millisecond + key := wsSessionKey{baseURL: "ws://example.test", accessToken: "tok", conversation: "sess"} + session := &wsSessionConn{ + key: key, + store: store, + sem: make(chan struct{}, 1), + } + if err := session.acquire(context.Background()); err != nil { + t.Fatalf("acquire: %v", err) + } + store.sessions[key] = session + session.release(false) + + deadline := time.After(time.Second) + for { + store.mu.Lock() + _, ok := store.sessions[key] + store.mu.Unlock() + if !ok { + return + } + select { + case <-deadline: + t.Fatal("idle websocket session was not evicted") + case <-time.After(5 * time.Millisecond): + } + } +} + +func TestWSSessionStorePruneToCapEvictsOldestReusableSession(t *testing.T) { + t.Parallel() + store := newWSSessionStore() + store.maxEntries = 2 + base := time.Unix(100, 0) + oldKey := wsSessionKey{baseURL: "ws://example.test", accessToken: "tok", conversation: "old"} + midKey := wsSessionKey{baseURL: "ws://example.test", accessToken: "tok", conversation: "mid"} + newKey := wsSessionKey{baseURL: "ws://example.test", accessToken: "tok", conversation: "new"} + oldSession := &wsSessionConn{key: oldKey, store: store, sem: make(chan struct{}, 1), lastUsed: base} + midSession := &wsSessionConn{key: midKey, store: store, sem: make(chan struct{}, 1), lastUsed: base.Add(time.Second)} + newSession := &wsSessionConn{key: newKey, store: store, sem: make(chan struct{}, 1), lastUsed: base.Add(2 * time.Second)} + store.sessions[oldKey] = oldSession + store.sessions[midKey] = midSession + store.sessions[newKey] = newSession + + store.pruneToCapLocked(newSession) + + if _, ok := store.sessions[oldKey]; ok { + t.Fatal("oldest reusable session was not evicted") + } + if _, ok := store.sessions[midKey]; !ok { + t.Fatal("middle session was unexpectedly evicted") + } + if _, ok := store.sessions[newKey]; !ok { + t.Fatal("protected session was unexpectedly evicted") + } +} + +func TestWSSessionStoreCloseIdleClosesOrphanedSession(t *testing.T) { + t.Parallel() + clientConn, serverConn := newTestWebSocketPair(t) + defer func() { _ = serverConn.Close() }() + + store := newWSSessionStore() + key := wsSessionKey{baseURL: "ws://example.test", accessToken: "tok", conversation: "sess"} + session := &wsSessionConn{ + key: key, + store: store, + sem: make(chan struct{}, 1), + conn: clientConn, + } + store.sessions[key] = &wsSessionConn{ + key: key, + store: store, + sem: make(chan struct{}, 1), + } + + store.closeIdle(key, session) + + if session.conn != nil { + t.Fatal("orphaned idle session conn was not closed") + } +} + +func TestWSStreamServerCloseIsNotReusable(t *testing.T) { + t.Parallel() + clientConn, serverConn := newTestWebSocketPair(t) + defer func() { _ = serverConn.Close() }() + if err := serverConn.WriteMessage(gorillawebsocket.CloseMessage, gorillawebsocket.FormatCloseMessage(gorillawebsocket.CloseNormalClosure, "")); err != nil { + t.Fatal(err) + } + var released bool + var closeConn bool + stream := newWSStream(clientConn, 0) + stream.release = func(close bool) { + released = true + closeConn = close + } + + _, err := stream.Recv(context.Background()) + if !errors.Is(err, io.EOF) { + t.Fatalf("Recv error = %v, want EOF", err) + } + if !released || !closeConn { + t.Fatalf("release = (%v, %v), want released closeConn=true", released, closeConn) + } +} + +func TestWSStreamSkipsEmptyTextFrames(t *testing.T) { + t.Parallel() + clientConn, serverConn := newTestWebSocketPair(t) + defer func() { _ = serverConn.Close() }() + if err := serverConn.WriteMessage(gorillawebsocket.TextMessage, []byte(" ")); err != nil { + t.Fatal(err) + } + if err := serverConn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.created","response":{"id":"resp_1"}}`)); err != nil { + t.Fatal(err) + } + stream := newWSStream(clientConn, 0) + stream.release = func(bool) {} + + ev, err := stream.Recv(context.Background()) + if err != nil { + t.Fatalf("Recv: %v", err) + } + if ev.Kind != lipapi.EventResponseStarted { + t.Fatalf("event kind = %q, want %q", ev.Kind, lipapi.EventResponseStarted) + } +} + +func TestWriteWSResponseCreateClosedConnIsFallbackEligible(t *testing.T) { + t.Parallel() + clientConn, serverConn := newTestWebSocketPair(t) + _ = serverConn.Close() + _ = clientConn.Close() + + err := writeWSResponseCreate(context.Background(), clientConn, json.RawMessage(`{"type":"response.create"}`)) + if !isWSFallbackError(context.Background(), err) { + t.Fatalf("write error = %v, want WS fallback-eligible", err) + } +} + +func newTestWebSocketPair(t *testing.T) (*gorillawebsocket.Conn, *gorillawebsocket.Conn) { + t.Helper() + upgrader := gorillawebsocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + serverConnCh := make(chan *gorillawebsocket.Conn, 1) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + serverConnCh <- conn + })) + t.Cleanup(srv.Close) + clientConn, _, err := gorillawebsocket.DefaultDialer.Dial("ws"+strings.TrimPrefix(srv.URL, "http"), nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = clientConn.Close() }) + select { + case serverConn := <-serverConnCh: + return clientConn, serverConn + case <-time.After(time.Second): + t.Fatal("timeout waiting for websocket server connection") + return nil, nil + } +} diff --git a/internal/plugins/backends/protocols/openairesponsestream/mapper.go b/internal/plugins/backends/protocols/openairesponsestream/mapper.go index 044b09d2..c455f83a 100644 --- a/internal/plugins/backends/protocols/openairesponsestream/mapper.go +++ b/internal/plugins/backends/protocols/openairesponsestream/mapper.go @@ -22,6 +22,7 @@ type Mapper struct { toolCallStarted map[string]bool toolCallArgDeltas map[string]bool toolCallFinished map[string]bool + pendingToolArgs map[string][]string } func New(pending *stream.PendingEventQueue) *Mapper { @@ -30,6 +31,7 @@ func New(pending *stream.PendingEventQueue) *Mapper { toolCallStarted: make(map[string]bool), toolCallArgDeltas: make(map[string]bool), toolCallFinished: make(map[string]bool), + pendingToolArgs: make(map[string][]string), } } @@ -123,7 +125,10 @@ func (m *Mapper) ToolCallAdded(id, name string) error { if m.toolCallStarted[id] { return nil } - return m.emitToolCallStarted(id, name) + if err := m.emitToolCallStarted(id, name); err != nil { + return err + } + return m.flushPendingToolArgs(id) } func (m *Mapper) ToolCallArgsDelta(id, delta string) error { @@ -134,9 +139,8 @@ func (m *Mapper) ToolCallArgsDelta(id, delta string) error { return err } if !m.toolCallStarted[id] { - if err := m.emitToolCallStarted(id, ""); err != nil { - return err - } + m.pendingToolArgs[id] = append(m.pendingToolArgs[id], delta) + return nil } m.toolCallArgDeltas[id] = true return m.pending.Push(lipapi.Event{ @@ -158,6 +162,9 @@ func (m *Mapper) FinishToolCallArguments(id, name, arguments string) error { return err } } + if err := m.flushPendingToolArgs(id); err != nil { + return err + } if !m.toolCallArgDeltas[id] && arguments != "" { if err := m.pending.Push(lipapi.Event{ Kind: lipapi.EventToolCallArgsDelta, @@ -182,6 +189,9 @@ func (m *Mapper) EmitCompletedToolCall(id, name, arguments string) error { return err } } + if err := m.flushPendingToolArgs(id); err != nil { + return err + } if !m.toolCallArgDeltas[id] && arguments != "" { if err := m.pending.Push(lipapi.Event{ Kind: lipapi.EventToolCallArgsDelta, @@ -236,6 +246,53 @@ func (m *Mapper) emitToolCallStarted(id, name string) error { }) } +func (m *Mapper) flushPendingToolArgs(id string) error { + deltas := m.pendingToolArgs[id] + if len(deltas) == 0 { + return nil + } + delete(m.pendingToolArgs, id) + for _, delta := range deltas { + m.toolCallArgDeltas[id] = true + if err := m.pending.Push(lipapi.Event{ + Kind: lipapi.EventToolCallArgsDelta, + ToolCallID: id, + Delta: delta, + }); err != nil { + return err + } + } + return nil +} + +// RemapToolCallID consolidates tool-call state buffered under oldID onto newID. +// It is used when a tool call's real call_id is learned after argument deltas +// were already buffered under the provisional item-only ID, so pending args and +// started/arg-delta/finished flags all move onto the canonical ID instead of +// fragmenting into two tool calls. Remapping is a no-op when no state exists +// under oldID. +func (m *Mapper) RemapToolCallID(oldID, newID string) { + if oldID == "" || newID == "" || oldID == newID { + return + } + if deltas, ok := m.pendingToolArgs[oldID]; ok { + m.pendingToolArgs[newID] = append(m.pendingToolArgs[newID], deltas...) + delete(m.pendingToolArgs, oldID) + } + if m.toolCallStarted[oldID] { + m.toolCallStarted[newID] = true + delete(m.toolCallStarted, oldID) + } + if m.toolCallArgDeltas[oldID] { + m.toolCallArgDeltas[newID] = true + delete(m.toolCallArgDeltas, oldID) + } + if m.toolCallFinished[oldID] { + m.toolCallFinished[newID] = true + delete(m.toolCallFinished, oldID) + } +} + // EmitOutputMediaFromResponse maps assistant message media in a completed Responses payload. func EmitOutputMediaFromResponse(m *Mapper, resp responses.Response) error { for _, item := range resp.Output { diff --git a/internal/plugins/backends/protocols/openairesponsestream/mapper_test.go b/internal/plugins/backends/protocols/openairesponsestream/mapper_test.go index c5461e20..7044b990 100644 --- a/internal/plugins/backends/protocols/openairesponsestream/mapper_test.go +++ b/internal/plugins/backends/protocols/openairesponsestream/mapper_test.go @@ -316,3 +316,75 @@ func TestMapper_completedTextFallback_onlyWhenNoTextDeltas(t *testing.T) { t.Fatalf("texts after delta: %v", texts) } } + +func TestMapper_remapToolCallID_movesBufferedArgsOntoCanonicalID(t *testing.T) { + t.Parallel() + m, q := newTestMapper() + // Args arrive before the tool call is added, so they buffer under the + // provisional item-only ID. + if err := m.ToolCallArgsDelta("fc_late", `{"filePath":`); err != nil { + t.Fatal(err) + } + // Learning the real call_id remaps the buffered args onto it. + m.RemapToolCallID("fc_late", "call_late") + if err := m.ToolCallAdded("call_late", "read"); err != nil { + t.Fatal(err) + } + if err := m.FinishToolCallArguments("call_late", "read", `{"filePath":"x"}`); err != nil { + t.Fatal(err) + } + + var startedCount, argDeltaCount, finishedCount int + var startedID string + var args strings.Builder + for _, ev := range stream.DrainPending(q) { + switch ev.Kind { + case lipapi.EventToolCallStarted: + startedCount++ + startedID = ev.ToolCallID + case lipapi.EventToolCallArgsDelta: + argDeltaCount++ + args.WriteString(ev.Delta) + if ev.ToolCallID != "call_late" { + t.Fatalf("args delta under id %q, want call_late", ev.ToolCallID) + } + case lipapi.EventToolCallFinished: + finishedCount++ + if ev.ToolCallID != "call_late" { + t.Fatalf("finished under id %q, want call_late", ev.ToolCallID) + } + } + } + if startedCount != 1 || startedID != "call_late" { + t.Fatalf("started = %d / %q, want 1 / call_late", startedCount, startedID) + } + if argDeltaCount != 1 || args.String() != `{"filePath":` { + t.Fatalf("args = %d / %q, want 1 / {\"filePath\": (incremental preserved, full args suppressed)", argDeltaCount, args.String()) + } + if finishedCount != 1 { + t.Fatalf("finished = %d, want 1", finishedCount) + } +} + +func TestMapper_remapToolCallID_noOpWhenIDsEqualOrEmpty(t *testing.T) { + t.Parallel() + m, q := newTestMapper() + if err := m.ToolCallArgsDelta("fc_1", `{"x":1}`); err != nil { + t.Fatal(err) + } + m.RemapToolCallID("fc_1", "fc_1") // equal -> no-op + m.RemapToolCallID("", "call_1") // empty old -> no-op + m.RemapToolCallID("fc_1", "") // empty new -> no-op + if err := m.ToolCallAdded("fc_1", "get"); err != nil { + t.Fatal(err) + } + var args strings.Builder + for _, ev := range stream.DrainPending(q) { + if ev.Kind == lipapi.EventToolCallArgsDelta { + args.WriteString(ev.Delta) + } + } + if args.String() != `{"x":1}` { + t.Fatalf("args = %q, want original buffered delta preserved after no-op remaps", args.String()) + } +} diff --git a/internal/plugins/features/codexclientcompat/compat.go b/internal/plugins/features/codexclientcompat/compat.go index 33d92c30..937d1be8 100644 --- a/internal/plugins/features/codexclientcompat/compat.go +++ b/internal/plugins/features/codexclientcompat/compat.go @@ -15,17 +15,16 @@ const ( droidBridgeMarker = "Factory Droid compatibility mode" hermesBridgeMarker = "Hermes Agent compatibility mode" - extAgentKey = "agent" - extUserAgentKey = "user_agent" - extCodexAgentKey = "openai_codex.agent" - extHeadersKey = "headers" - // ponytail: mirrors openaicodex.ExtToolStrict; local const avoids feature→backend import. - extCodexToolStrictKey = "openai_codex.tool_strict" + extAgentKey = "agent" + extUserAgentKey = "user_agent" + extCodexAgentKey = "openai_codex.agent" + extHeadersKey = "headers" + extCodexToolStrictKey = "openai_codex.tool_strict" + extCodexIgnoreUnsupportedGenParamsKey = "openai_codex.ignore_unsupported_gen_params" // hermesIdentitySentence is the exact upstream Hermes Agent identity sentence. hermesIdentitySentence = "You are Hermes Agent, an intelligent AI assistant created by Nous Research." - // ponytail: mirrors openaicodex default when instructions empty so bridge appends after base Codex prompt. codexDefaultInstruction = "You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer." ) @@ -80,8 +79,8 @@ var compatBridges = []compatBridge{ matchesAgent: openCodeAgentMatch, matchesPrompt: openCodePromptMatch, filter: isOpenCodeHarnessText, - build: func(*lipapi.Call) string { return buildOpenCodeBridge() }, - beforeApply: convertOrphanedToolResults, + build: func(call *lipapi.Call) string { return buildOpenCodeBridge(len(call.Tools) > 0) }, + beforeApply: applyOpenCodeToolHistoryCompat, }, { marker: piBridgeMarker, @@ -114,15 +113,19 @@ func ApplyCompat(call *lipapi.Call) { in := detectCompatInput(call) bridge := selectCompatBridge(in) if bridge == nil { - return + bridge = fallbackCompatBridge(call) + if bridge == nil { + return + } } + applyIgnoreUnsupportedGenParams(call) hasTools := len(call.Tools) > 0 call.Messages = filterHarnessMessages(call.Messages, bridge.filter) call.Instructions = filterHarnessMessages(call.Instructions, bridge.filter) if bridge.beforeApply != nil { bridge.beforeApply(call) } - if !hasTools { + if !hasTools && bridge.marker != openCodeBridgeMarker { return } block := bridge.build(call) @@ -144,6 +147,82 @@ func selectCompatBridge(in compatInput) *compatBridge { return nil } +func fallbackCompatBridge(call *lipapi.Call) *compatBridge { + if call == nil || len(call.Tools) > 0 || !hasStructuredToolTranscript(call.Messages) { + return nil + } + for i := range compatBridges { + if compatBridges[i].marker == openCodeBridgeMarker { + return &compatBridges[i] + } + } + return nil +} + +func applyOpenCodeToolHistoryCompat(call *lipapi.Call) { + convertOrphanedToolResults(call) +} + +func hasStructuredToolTranscript(msgs []lipapi.Message) bool { + for _, m := range msgs { + if m.Role == lipapi.RoleTool { + for _, p := range m.Parts { + if p.Kind == lipapi.PartToolResult { + return true + } + } + } + if m.Role != lipapi.RoleAssistant { + continue + } + for _, p := range m.Parts { + if p.Kind != lipapi.PartJSON { + continue + } + if isFunctionCallPart(p) { + return true + } + } + } + return false +} + +func isFunctionCallPart(p lipapi.Part) bool { + if len(p.Content) == 0 { + return false + } + var fc struct { + Type string `json:"type"` + CallID string `json:"call_id"` + ID string `json:"id"` + Name string `json:"name"` + Function *struct { + Name string `json:"name"` + } `json:"function"` + } + if json.Unmarshal(p.Content, &fc) != nil { + return false + } + if !strings.EqualFold(fc.Type, "function_call") && !strings.EqualFold(fc.Type, "function") { + return false + } + id := firstNonEmpty(fc.CallID, fc.ID) + name := strings.TrimSpace(fc.Name) + if name == "" && fc.Function != nil { + name = strings.TrimSpace(fc.Function.Name) + } + return strings.TrimSpace(id) != "" && name != "" +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + func detectCompatInput(call *lipapi.Call) compatInput { return compatInput{ agents: collectAgentCandidates(call), @@ -442,7 +521,9 @@ func collectKnownToolCallIDs(msgs []lipapi.Message) map[string]struct{} { if json.Unmarshal(p.Content, &fc) != nil { continue } - if !strings.EqualFold(fc.Type, "function_call") { + // Accept Responses-style ("function_call") and Chat Completions-style + // ("function") assistant tool calls so matching tool results are preserved. + if !strings.EqualFold(fc.Type, "function_call") && !strings.EqualFold(fc.Type, "function") { continue } id := strings.TrimSpace(fc.CallID) @@ -457,6 +538,30 @@ func collectKnownToolCallIDs(msgs []lipapi.Message) map[string]struct{} { return known } +func argumentText(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + if raw[0] == '"' { + var s string + if json.Unmarshal(raw, &s) == nil { + return s + } + } + return string(raw) +} + +func messagePartText(p lipapi.Part) string { + switch p.Kind { + case lipapi.PartText: + return p.Text + case lipapi.PartToolResult, lipapi.PartJSON: + return string(p.Content) + default: + return string(p.Kind) + } +} + func convertOrphanedToolResult(p lipapi.Part) lipapi.Message { rendered := string(p.Content) if len(p.Content) == 0 { @@ -472,20 +577,41 @@ func convertOrphanedToolResult(p lipapi.Part) lipapi.Message { } } -func buildOpenCodeBridge() string { - return openCodeBridgeMarker + ":\n" + - "- Prefer the available client shell tool when command execution is needed.\n" + +func buildOpenCodeBridge(hasTools bool) string { + var b strings.Builder + b.WriteString(openCodeBridgeMarker) + b.WriteString(":\n") + if hasTools { + // Keep this guidance generic. OpenCode tool names and schemas vary by + // installation, plugin, and session; the structured tool list is the only + // authoritative source of callable names. Duplicating names in prose makes + // random session-specific tools look universal and can bias the model toward + // tools the current request did not actually expose. + b.WriteString("- Prefer the available client shell tool when command execution is needed.\n") + } else { + b.WriteString("- No callable client tools are available in this request. Do not attempt tool calls; respond in plain text or ask the user/client to provide tools.\n") + } + b.WriteString("- Never emit textual tool-call syntax such as `to=functions.` or JSON tool calls in assistant content; use structured tool calls only when tools are available.\n") + if !hasTools { + // No tools are exposed, so do not append criticalInstruction("OpenCode"): + // it tells the model to use agent-provided tools, contradicting the + // "no callable client tools" guidance above and risking spurious tool calls. + return b.String() + } + b.WriteString( "- For bash-style tools, arguments MUST be a JSON object with string " + - "`command` and string `description`.\n" + - "- Bash-style tools MAY include numeric `timeout` in milliseconds " + - "and string `workdir` when the client schema exposes them.\n" + - "- Never emit array-valued `command` arguments for shell execution.\n" + - "- Do not use `apply_patch`; use the client's native file editing tools instead.\n" + - "- Do not use `update_plan` or `read_plan`; use the client's task tools instead.\n" + - "- If you need a working directory, prefer `workdir` over `cd` commands " + - "or embedding cwd text in `description`.\n" + - "\n" + - criticalInstruction("OpenCode") + "`command` and string `description`.\n" + + "- Bash-style tools MAY include numeric `timeout` in milliseconds " + + "and string `workdir` when the client schema exposes them.\n" + + "- Never emit array-valued `command` arguments for shell execution.\n" + + "- Do not use `apply_patch`; use the client's native file editing tools instead.\n" + + "- Do not use `update_plan` or `read_plan`; use the client's task tools instead.\n" + + "- If you need a working directory, prefer `workdir` over `cd` commands " + + "or embedding cwd text in `description`.\n" + + "\n" + + criticalInstruction("OpenCode"), + ) + return b.String() } func buildPiBridge() string { @@ -553,6 +679,16 @@ func applyHermesToolStrict(call *lipapi.Call) { call.Extensions[extCodexToolStrictKey] = json.RawMessage("false") } +func applyIgnoreUnsupportedGenParams(call *lipapi.Call) { + if call.Extensions == nil { + call.Extensions = map[string]json.RawMessage{} + } + if _, ok := call.Extensions[extCodexIgnoreUnsupportedGenParamsKey]; ok { + return + } + call.Extensions[extCodexIgnoreUnsupportedGenParamsKey] = json.RawMessage("true") +} + func sortedNativeDroidTools() []string { out := make([]string, 0, len(droidNativeToolNames)) for name := range droidNativeToolNames { diff --git a/internal/plugins/features/codexclientcompat/plugin.go b/internal/plugins/features/codexclientcompat/plugin.go index e672eca1..4d1ca34d 100644 --- a/internal/plugins/features/codexclientcompat/plugin.go +++ b/internal/plugins/features/codexclientcompat/plugin.go @@ -9,7 +9,7 @@ import ( const ( defaultOrder = 50 - targetBackendID = "openai-codex" // ponytail: mirrors openaicodex.ID; local const avoids feature→backend import. + targetBackendID = "openai-codex" // ponytail: mirrors openaicodex.ID; local const avoids feature->backend import. ) type requestPartHook struct { @@ -34,6 +34,10 @@ func (h requestPartHook) HandleRequestParts(_ context.Context, call *lipapi.Call if meta.BackendID != targetBackendID { return nil } + if call == nil { + return nil + } + applyIgnoreUnsupportedGenParams(call) ApplyCompat(call) return nil } diff --git a/internal/plugins/features/codexclientcompat/plugin_test.go b/internal/plugins/features/codexclientcompat/plugin_test.go index 399ecf21..7b5e0bb4 100644 --- a/internal/plugins/features/codexclientcompat/plugin_test.go +++ b/internal/plugins/features/codexclientcompat/plugin_test.go @@ -234,6 +234,10 @@ func TestRequestPartHook_mutatesWhenBackendIsOpenAICodex(t *testing.T) { func TestApplyOpenCodeCompat_dedupBridgeOrphanToolOutput(t *testing.T) { t.Parallel() + // The bridge is intentionally generic: OpenCode deployments can expose any + // combination of built-in, MCP, or plugin tools. Prompt text must not duplicate + // request-specific names such as "bash" because the structured tool schema is + // the real availability contract and prose can outlive the current request. call := &lipapi.Call{ Instructions: []lipapi.Message{{ Role: lipapi.RoleSystem, @@ -290,6 +294,7 @@ func TestApplyOpenCodeCompat_dedupBridgeOrphanToolOutput(t *testing.T) { for _, want := range []string{ "string `command` and string `description`", "Never emit array-valued `command`", + "Never emit textual tool-call syntax", "Do not use `apply_patch`", "Do not use `update_plan` or `read_plan`", "prefer `workdir`", @@ -299,6 +304,267 @@ func TestApplyOpenCodeCompat_dedupBridgeOrphanToolOutput(t *testing.T) { t.Fatalf("bridge missing %q", want) } } + for _, unwanted := range []string{"Use only tool names that are available in this request", "`bash`"} { + if strings.Contains(instructions, unwanted) { + t.Fatalf("OpenCode bridge must not duplicate request-specific tool names: %q", instructions) + } + } +} + +func TestApplyOpenCodeCompat_keepsToolResultForChatCompletionsToolCall(t *testing.T) { + t.Parallel() + call := &lipapi.Call{ + Instructions: []lipapi.Message{{ + Role: lipapi.RoleSystem, + Parts: []lipapi.Part{lipapi.TextPart("Base instructions")}, + }}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleSystem, Parts: []lipapi.Part{lipapi.TextPart("OpenCode tool environment prompt for bash and edit tools")}}, + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("run it")}}, + {Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"call_abc","type":"function","function":{"name":"bash","arguments":"{\"command\":\"echo pong\"}"}}`), + }}}, + {Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_abc", + Content: json.RawMessage(`{"status":"ok"}`), + }}}, + }, + Tools: []lipapi.ToolDef{{Name: "bash"}}, + Extensions: map[string]json.RawMessage{ + extAgentKey: json.RawMessage(`"opencode"`), + }, + } + runHook(t, call, targetBackendID) + + preserved := false + for _, m := range call.Messages { + if m.Role != lipapi.RoleTool { + continue + } + for _, p := range m.Parts { + if p.Kind == lipapi.PartToolResult && p.ToolCallID == "call_abc" { + preserved = true + } + } + } + if !preserved { + t.Fatalf("expected tool result for call_abc preserved as RoleTool: %#v", call.Messages) + } + for _, m := range call.Messages { + if strings.Contains(messageText(m), "Prior tool output") { + t.Fatalf("tool result matching a chat-completions tool call must not be treated as orphaned: %#v", call.Messages) + } + } +} + +func TestApplyOpenCodeCompat_preservesMatchedToolProtocolWithTools(t *testing.T) { + t.Parallel() + // When tools are present, even old matched tool calls/results must remain + // structured. WebSocket continuation records prior output items and then sends + // only the delta input on the next turn; flattening matched history here would + // destroy the protocol lineage needed for previous_response_id continuation. + call := &lipapi.Call{ + Instructions: []lipapi.Message{{ + Role: lipapi.RoleSystem, + Parts: []lipapi.Part{lipapi.TextPart("Base instructions")}, + }}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("old request")}}, + {Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"old_call","type":"function","function":{"name":"grep","arguments":"{\"pattern\":\"old\"}"}}`), + }}}, + {Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "old_call", + Content: json.RawMessage(`{"old":true}`), + }}}, + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("current request")}}, + {Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"active_call","type":"function","function":{"name":"grep","arguments":"{\"pattern\":\"active\"}"}}`), + }}}, + {Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "active_call", + Content: json.RawMessage(`{"active":true}`), + }}}, + }, + Tools: []lipapi.ToolDef{{Name: "grep"}}, + Extensions: map[string]json.RawMessage{ + extAgentKey: json.RawMessage(`"opencode"`), + }, + } + runHook(t, call, targetBackendID) + + payload, err := openaicodex.PayloadForCall(call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, openaicodex.Config{}) + if err != nil { + t.Fatal(err) + } + raw, _ := json.Marshal(payload) + payloadJSON := string(raw) + if got := strings.Count(payloadJSON, `"type":"function_call",`); got != 2 { + t.Fatalf("matched function calls should remain structured, got %d: %s", got, payloadJSON) + } + if got := strings.Count(payloadJSON, `"type":"function_call_output"`); got != 2 { + t.Fatalf("matched function outputs should remain structured, got %d: %s", got, payloadJSON) + } + if !strings.Contains(payloadJSON, `"call_id":"old_call"`) { + t.Fatalf("matched stale tool call should remain protocol-shaped for WS continuation: %s", payloadJSON) + } + if !strings.Contains(payloadJSON, `"call_id":"active_call"`) { + t.Fatalf("active tool call must remain protocol-shaped: %s", payloadJSON) + } +} + +func TestApplyOpenCodeCompat_flattensNoToolsToolHistoryAsConversation(t *testing.T) { + t.Parallel() + // This is the no-tools counterpart to the structured-history test above. The + // client may resend a transcript containing tool history while exposing zero + // callable tools. In that mode the safest universal representation is ordinary + // conversation text: it preserves context, avoids backend protocol mismatch, and + // prevents the model from continuing with raw textual tool-call syntax. + call := &lipapi.Call{ + Instructions: []lipapi.Message{{ + Role: lipapi.RoleSystem, + Parts: []lipapi.Part{lipapi.TextPart("Base instructions")}, + }}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("inspect logs")}}, + {Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"call_abc","type":"function","function":{"name":"grep","arguments":"{\"pattern\":\"error\"}"}}`), + }}}, + {Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_abc", + Content: json.RawMessage(`{"matches":100}`), + }}}, + }, + Extensions: map[string]json.RawMessage{ + extAgentKey: json.RawMessage(`"opencode"`), + }, + } + runHook(t, call, targetBackendID) + + for _, m := range call.Messages { + if strings.Contains(messageText(m), "Prior tool output") && m.Role != lipapi.RoleUser { + t.Fatalf("flattened tool output must remain conversation history, got role %q in %#v", m.Role, call.Messages) + } + } + payload, err := openaicodex.PayloadForCall(call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, openaicodex.Config{}) + if err != nil { + t.Fatal(err) + } + if strings.Contains(payload.Instructions, "Prior tool output") { + t.Fatalf("tool history must not be folded into Codex instructions: %q", payload.Instructions) + } + raw, _ := json.Marshal(payload) + if strings.Contains(string(raw), `"type":"function_call"`) || strings.Contains(string(raw), `"type":"function_call_output"`) { + t.Fatalf("no-tools history must not remain protocol-shaped: %s", raw) + } + if !strings.Contains(string(raw), "Prior assistant tool call") || !strings.Contains(string(raw), "Prior tool output") { + t.Fatalf("no-tools history should be rendered as conversation text: %s", raw) + } +} + +func TestApplyOpenCodeCompat_noToolsAddsTextualToolCallGuard(t *testing.T) { + t.Parallel() + call := &lipapi.Call{ + Instructions: []lipapi.Message{{ + Role: lipapi.RoleSystem, + Parts: []lipapi.Part{lipapi.TextPart("Base instructions")}, + }}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("inspect logs")}}, + {Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"call_abc","type":"function","function":{"name":"grep","arguments":"{\"pattern\":\"error\"}"}}`), + }}}, + {Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_abc", + Content: json.RawMessage(`{"matches":100}`), + }}}, + }, + Extensions: map[string]json.RawMessage{ + extAgentKey: json.RawMessage(`"opencode"`), + }, + } + runHook(t, call, targetBackendID) + + instructions := joinInstructionText(call.Instructions) + for _, want := range []string{ + openCodeBridgeMarker, + "No callable client tools are available", + "Never emit textual tool-call syntax", + "to=functions.", + } { + if !strings.Contains(instructions, want) { + t.Fatalf("instructions missing %q: %q", want, instructions) + } + } + if call.Messages[0].Role != lipapi.RoleSystem || !strings.Contains(messageText(call.Messages[0]), openCodeBridgeMarker) { + t.Fatalf("expected bridge system message first: %#v", call.Messages[0]) + } +} + +func TestApplyCompat_noMarkerNoToolsToolHistoryAddsTextualToolCallGuard(t *testing.T) { + t.Parallel() + call := &lipapi.Call{ + Instructions: []lipapi.Message{{ + Role: lipapi.RoleSystem, + Parts: []lipapi.Part{lipapi.TextPart("Base instructions")}, + }}, + Messages: []lipapi.Message{ + {Role: lipapi.RoleUser, Parts: []lipapi.Part{lipapi.TextPart("inspect logs")}}, + {Role: lipapi.RoleAssistant, Parts: []lipapi.Part{{ + Kind: lipapi.PartJSON, + Content: json.RawMessage(`{"id":"call_abc","type":"function","function":{"name":"grep","arguments":"{\"pattern\":\"error\"}"}}`), + }}}, + {Role: lipapi.RoleTool, Parts: []lipapi.Part{{ + Kind: lipapi.PartToolResult, + ToolCallID: "call_abc", + Content: json.RawMessage(`{"matches":100}`), + }}}, + }, + } + runHook(t, call, targetBackendID) + + instructions := joinInstructionText(call.Instructions) + for _, want := range []string{ + openCodeBridgeMarker, + "No callable client tools are available", + "Never emit textual tool-call syntax", + } { + if !strings.Contains(instructions, want) { + t.Fatalf("instructions missing %q: %q", want, instructions) + } + } +} + +func TestApplyCompat_noMarkerNoToolsPlainChatDoesNotAddOpenCodeGuard(t *testing.T) { + t.Parallel() + call := &lipapi.Call{ + Instructions: []lipapi.Message{{ + Role: lipapi.RoleSystem, + Parts: []lipapi.Part{lipapi.TextPart("Base instructions")}, + }}, + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hello")}, + }}, + } + runHook(t, call, targetBackendID) + if instructions := joinInstructionText(call.Instructions); strings.Contains(instructions, openCodeBridgeMarker) { + t.Fatalf("plain no-tools chat must not get OpenCode guard: %q", instructions) + } } func TestApplyPiCompat_dedupBridgeNotDuplicated(t *testing.T) { @@ -401,7 +667,7 @@ func TestRequestPartHook_openCodeCompatPayloadShape(t *testing.T) { } runHook(t, &call, targetBackendID) payload, err := openaicodex.PayloadForCall(&call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }, openaicodex.Config{}) if err != nil { t.Fatal(err) @@ -411,6 +677,75 @@ func TestRequestPartHook_openCodeCompatPayloadShape(t *testing.T) { } } +func TestApplyCompat_setsIgnoreUnsupportedGenParamsExt(t *testing.T) { + t.Parallel() + maxTok := 512 + call := &lipapi.Call{ + Extensions: map[string]json.RawMessage{extAgentKey: json.RawMessage(`"opencode"`)}, + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hi")}, + }}, + Options: lipapi.GenerationOptions{MaxOutputTokens: &maxTok}, + } + runHook(t, call, targetBackendID) + raw, ok := call.Extensions[extCodexIgnoreUnsupportedGenParamsKey] + if !ok { + t.Fatal("expected ignore_unsupported_gen_params extension") + } + var ignore bool + if err := json.Unmarshal(raw, &ignore); err != nil || !ignore { + t.Fatalf("ignore_unsupported_gen_params = %s, want true", raw) + } + if _, err := openaicodex.PayloadForCall(call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, + }, openaicodex.Config{}); err != nil { + t.Fatalf("payload with compat ext: %v", err) + } +} + +func TestRequestPartHook_codexBackendSetsIgnoreUnsupportedGenParamsWithoutClientMarker(t *testing.T) { + t.Parallel() + maxTok := 512 + call := &lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("summarize prior context")}, + }}, + Options: lipapi.GenerationOptions{MaxOutputTokens: &maxTok}, + } + runHook(t, call, targetBackendID) + raw, ok := call.Extensions[extCodexIgnoreUnsupportedGenParamsKey] + if !ok { + t.Fatal("expected ignore_unsupported_gen_params extension") + } + var ignore bool + if err := json.Unmarshal(raw, &ignore); err != nil || !ignore { + t.Fatalf("ignore_unsupported_gen_params = %s, want true", raw) + } + if _, err := openaicodex.PayloadForCall(call, routing.AttemptCandidate{ + Primary: routing.Primary{Model: "gpt-5.4-mini"}, + }, openaicodex.Config{}); err != nil { + t.Fatalf("payload with codex compat ext: %v", err) + } +} + +func TestRequestPartHook_nonCodexBackendDoesNotSetIgnoreUnsupportedGenParamsWithoutClientMarker(t *testing.T) { + t.Parallel() + maxTok := 512 + call := &lipapi.Call{ + Messages: []lipapi.Message{{ + Role: lipapi.RoleUser, + Parts: []lipapi.Part{lipapi.TextPart("hi")}, + }}, + Options: lipapi.GenerationOptions{MaxOutputTokens: &maxTok}, + } + runHook(t, call, "openai-responses") + if _, ok := call.Extensions[extCodexIgnoreUnsupportedGenParamsKey]; ok { + t.Fatal("did not expect ignore_unsupported_gen_params extension for non-Codex backend") + } +} + func TestDetectHermesFromExtensionAgent(t *testing.T) { t.Parallel() in := detectCompatInput(&lipapi.Call{ @@ -571,7 +906,7 @@ func TestApplyHermesCompat_payloadShapeIncludesBridgeAndToolStrictFalse(t *testi } payload, err := openaicodex.PayloadForCall(call, routing.AttemptCandidate{ - Primary: routing.Primary{Model: "gpt-5.3-codex"}, + Primary: routing.Primary{Model: "gpt-5.3-codex-spark"}, }, openaicodex.Config{}) if err != nil { t.Fatal(err) diff --git a/internal/plugins/frontends/anthropic/handler.go b/internal/plugins/frontends/anthropic/handler.go index 004ca63c..60e3a4f8 100644 --- a/internal/plugins/frontends/anthropic/handler.go +++ b/internal/plugins/frontends/anthropic/handler.go @@ -5,6 +5,7 @@ import ( "log/slog" "net/http" "strings" + "time" "github.com/matdev83/go-llm-interactive-proxy/internal/core/diag" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/decodeqos" @@ -13,6 +14,7 @@ import ( "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/jsonguard" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/reqbody" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/streamdebug" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk/traffic" @@ -31,12 +33,14 @@ type Handler struct { Exec lipsdk.ExecutorView // DefaultRouteSelector is used when HeaderRouteSelector is absent. DefaultRouteSelector string - MaxRequestBodyBytes int64 - Log *slog.Logger - TrafficPorts traffic.PortBundle - DecodeLimiter *decodeqos.Limiter - PreRequestKeepalive lipsdk.FrontendKeepaliveConfig - Config Config + // RoutePrefixes are backend route-selector prefixes accepted from body model. + RoutePrefixes routeselect.PrefixSet + MaxRequestBodyBytes int64 + Log *slog.Logger + TrafficPorts traffic.PortBundle + DecodeLimiter *decodeqos.Limiter + PreRequestKeepalive lipsdk.FrontendKeepaliveConfig + Config Config } func (h *Handler) maxBodyLimit() int64 { @@ -114,7 +118,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if sel == "" { - sel = routeselect.FromModelOrDefault(body, h.DefaultRouteSelector) + sel = h.RoutePrefixes.FromModelOrDefault(body, h.DefaultRouteSelector) } decoded, err := DecodeMessageRequest(body, DecodeOptions{ RouteSelector: sel, @@ -123,9 +127,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }) releaseDecode() if err != nil { - if h.Log != nil { - diag.LogError(ctx, h.Log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) + log := h.Log + if log == nil { + log = slog.Default() } + diag.LogError(ctx, log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) + streamdebug.LogDecodeFailure(ctx, log, ID, body, err) h.logWriteJSONErr(ctx, "write error json failed", WriteErrorJSON(w, http.StatusBadRequest, "invalid request JSON", "invalid_request_error")) return } @@ -145,6 +152,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { SessionID: call.Session.CorrelationID(), }, "http", ct, body) + streamdebug.LogCall(ctx, h.Log, ID, call, decoded.Stream, len(body), sel) + executeStart := time.Now() es, err := h.execute(ctx, w, call, decoded.Stream) if err != nil { out := execerr.ClassifyExecute(err) @@ -168,7 +177,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + streamdebug.LogExecuteOpened(ctx, h.Log, ID, call, executeStart) ctx = diag.EnsureCallDiag(ctx, traceID, call.Session.ALegID) + es = streamdebug.Wrap(ctx, h.Log, ID, call, es, executeStart) opts := EncodeOptions{ MessageID: "msg_" + diag.StableCallToken(call), diff --git a/internal/plugins/frontends/anthropic/integration_test.go b/internal/plugins/frontends/anthropic/integration_test.go index dd9b45ff..f5f96644 100644 --- a/internal/plugins/frontends/anthropic/integration_test.go +++ b/internal/plugins/frontends/anthropic/integration_test.go @@ -13,6 +13,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" front "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/anthropic" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" refcli "github.com/matdev83/go-llm-interactive-proxy/internal/refclient/anthropicmessages" "github.com/matdev83/go-llm-interactive-proxy/internal/refclient/refclienttest" "github.com/matdev83/go-llm-interactive-proxy/internal/testkit" @@ -360,7 +361,7 @@ func TestIntegration_routeHeaderOverridesDefault(t *testing.T) { t.Parallel() var capture sync.Map ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &capture) - h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route"} + h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route", RoutePrefixes: routeselect.NewPrefixSet([]string{"stub"})} mux := http.NewServeMux() mux.Handle("/v1/messages", h) srv := httptest.NewServer(mux) @@ -396,7 +397,7 @@ func TestIntegration_modelRouteSelectorUsedWhenHeaderAbsent(t *testing.T) { t.Parallel() var capture sync.Map ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &capture) - h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route"} + h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route", RoutePrefixes: routeselect.NewPrefixSet([]string{"stub"})} mux := http.NewServeMux() mux.Handle("/v1/messages", h) srv := httptest.NewServer(mux) diff --git a/internal/plugins/frontends/gemini/handler.go b/internal/plugins/frontends/gemini/handler.go index 69df7572..70745574 100644 --- a/internal/plugins/frontends/gemini/handler.go +++ b/internal/plugins/frontends/gemini/handler.go @@ -5,6 +5,7 @@ import ( "log/slog" "net/http" "strings" + "time" "github.com/matdev83/go-llm-interactive-proxy/internal/core/diag" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/decodeqos" @@ -13,6 +14,7 @@ import ( "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/jsonguard" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/reqbody" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/streamdebug" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk/traffic" @@ -29,12 +31,14 @@ type Handler struct { Exec lipsdk.ExecutorView // DefaultRouteSelector is used when HeaderRouteSelector is absent. DefaultRouteSelector string - MaxRequestBodyBytes int64 - Log *slog.Logger - TrafficPorts traffic.PortBundle - DecodeLimiter *decodeqos.Limiter - PreRequestKeepalive lipsdk.FrontendKeepaliveConfig - Config Config + // RoutePrefixes are backend route-selector prefixes accepted from URL model. + RoutePrefixes routeselect.PrefixSet + MaxRequestBodyBytes int64 + Log *slog.Logger + TrafficPorts traffic.PortBundle + DecodeLimiter *decodeqos.Limiter + PreRequestKeepalive lipsdk.FrontendKeepaliveConfig + Config Config } func (h *Handler) maxBodyLimit() int64 { @@ -111,7 +115,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if sel == "" { - sel = routeselect.InlineOrDefault(model, h.DefaultRouteSelector) + sel = h.RoutePrefixes.InlineOrDefault(model, h.DefaultRouteSelector) } decoded, err := DecodeGenerateContentRequest(body, DecodeOptions{ RouteSelector: sel, @@ -121,9 +125,12 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }) releaseDecode() if err != nil { - if h.Log != nil { - diag.LogError(ctx, h.Log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) + log := h.Log + if log == nil { + log = slog.Default() } + diag.LogError(ctx, log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) + streamdebug.LogDecodeFailure(ctx, log, ID, body, err) h.logWriteJSONErr(ctx, "write error json failed", WriteErrorJSON(w, http.StatusBadRequest, "invalid request JSON")) return } @@ -143,6 +150,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { SessionID: call.Session.CorrelationID(), }, "http", ct, body) + streamdebug.LogCall(ctx, h.Log, ID, call, stream, len(body), sel) + executeStart := time.Now() es, err := h.execute(ctx, w, call, stream) if err != nil { out := execerr.ClassifyExecute(err) @@ -157,7 +166,9 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - ctx = diag.EnsureCallDiag(ctx, traceID, call.Session.ALegID) + streamdebug.LogExecuteOpened(ctx, h.Log, ID, call, executeStart) + ctx = diag.EnsureCallDiag(ctx, traceID, strings.TrimSpace(call.Session.ALegID)) + es = streamdebug.Wrap(ctx, h.Log, ID, call, es, executeStart) opts := EncodeOptions{ExposeLipUsageExtensions: h.Config.ExposeLipUsageExtensions} if stream { diff --git a/internal/plugins/frontends/gemini/integration_test.go b/internal/plugins/frontends/gemini/integration_test.go index 1ad27725..36134b41 100644 --- a/internal/plugins/frontends/gemini/integration_test.go +++ b/internal/plugins/frontends/gemini/integration_test.go @@ -11,6 +11,7 @@ import ( "testing" front "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/gemini" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" refcli "github.com/matdev83/go-llm-interactive-proxy/internal/refclient/gemini" "github.com/matdev83/go-llm-interactive-proxy/internal/refclient/refclienttest" "github.com/matdev83/go-llm-interactive-proxy/internal/testkit" @@ -346,7 +347,7 @@ func TestIntegration_routeHeaderOverridesDefault(t *testing.T) { t.Parallel() var capture sync.Map ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &capture) - h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route"} + h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route", RoutePrefixes: routeselect.NewPrefixSet([]string{"stub"})} mux := http.NewServeMux() mux.Handle("/", h) srv := httptest.NewServer(mux) @@ -383,7 +384,7 @@ func TestIntegration_modelRouteSelectorUsedWhenHeaderAbsent(t *testing.T) { t.Parallel() var capture sync.Map ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &capture) - h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route"} + h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route", RoutePrefixes: routeselect.NewPrefixSet([]string{"stub"})} mux := http.NewServeMux() mux.Handle("/", h) srv := httptest.NewServer(mux) diff --git a/internal/plugins/frontends/openailegacy/codex_body_routing_test.go b/internal/plugins/frontends/openailegacy/codex_body_routing_test.go new file mode 100644 index 00000000..a0203123 --- /dev/null +++ b/internal/plugins/frontends/openailegacy/codex_body_routing_test.go @@ -0,0 +1,79 @@ +package openailegacy_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/b2bua" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/execbackend" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/hooks" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/runtime" + front "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/openailegacy" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" + "github.com/matdev83/go-llm-interactive-proxy/internal/testkit" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +// TestIntegration_openaiCodexURIReasoningEffortOverridesBody proves that a reasoning_effort +// URI param on the model selector OVERRIDES the Chat Completions body's reasoning_effort +// field. URI params are explicit routing directives and take precedence over per-request +// body settings. +func TestIntegration_openaiCodexURIReasoningEffortOverridesBody(t *testing.T) { + t.Parallel() + st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) + if err != nil { + t.Fatal(err) + } + var captured lipapi.Call + ex := &runtime.Executor{ + Store: st, + Bus: hooks.New(hooks.Config{}), + Rand: routing.NewSeededRng(42), + Backends: map[string]execbackend.Backend{ + "openai-codex": { + Caps: lipapi.NewBackendCaps(lipapi.CapabilityStreaming, lipapi.CapabilityReasoning, lipapi.CapabilityTools), + Open: func(_ context.Context, call lipapi.Call, _ routing.AttemptCandidate) (lipapi.ManagedEventStream, error) { + captured = call + return lipapi.NewFixedEventStream([]lipapi.Event{ + {Kind: lipapi.EventResponseStarted}, + {Kind: lipapi.EventMessageStarted}, + {Kind: lipapi.EventTextDelta, Delta: "ok"}, + {Kind: lipapi.EventResponseFinished}, + }), nil + }, + }, + }, + } + testkit.WireConformanceExecutorSecureSession(t, ex) + + h := &front.Handler{Exec: ex, DefaultRouteSelector: "openai-codex:gpt-5.5", RoutePrefixes: routeselect.NewPrefixSet([]string{"openai-codex"})} + mux := http.NewServeMux() + mux.Handle("/v1/chat/completions", h) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, srv.URL+"/v1/chat/completions", + strings.NewReader(`{"model":"openai-codex:openai/gpt-5.4-mini?reasoning_effort=xhigh","reasoning_effort":"medium","stream":false,"messages":[{"role":"user","content":"hi"}]}`)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + res, err := testkit.IntegrationHTTPClient(nil).Do(req) + if err != nil { + t.Fatal(err) + } + defer func() { _ = res.Body.Close() }() + if res.StatusCode != http.StatusOK { + b, _ := io.ReadAll(res.Body) + t.Fatalf("status %d body %s", res.StatusCode, string(b)) + } + + if captured.Options.ReasoningEffort != "xhigh" { + t.Fatalf("call.Options.ReasoningEffort %q, want %q (URI param must override body field)", captured.Options.ReasoningEffort, "xhigh") + } +} diff --git a/internal/plugins/frontends/openailegacy/decode.go b/internal/plugins/frontends/openailegacy/decode.go index ba8ccb92..5981c5cd 100644 --- a/internal/plugins/frontends/openailegacy/decode.go +++ b/internal/plugins/frontends/openailegacy/decode.go @@ -45,6 +45,7 @@ type wireCreate struct { TopP *float64 `json:"top_p"` MaxTokens *int `json:"max_tokens"` ParallelToolCalls *bool `json:"parallel_tool_calls"` + ReasoningEffort string `json:"reasoning_effort"` StreamOptions json.RawMessage `json:"stream_options"` Metadata map[string]string `json:"metadata,omitempty"` } @@ -55,9 +56,14 @@ var legacyKnownBodyKeys = map[string]bool{ "parallel_tool_calls": true, "stream_options": true, "metadata": true, "max_completion_tokens": true, "n": true, "stop": true, "presence_penalty": true, "frequency_penalty": true, "logit_bias": true, "logprobs": true, "top_logprobs": true, - "seed": true, "suffix": true, + "seed": true, "suffix": true, "reasoning_effort": true, } +var ( + errEmptyAssistantMessage = errors.New("empty assistant message") + errEmptyChatContent = errors.New("message content string is empty") +) + // DecodeChatRequest maps a Chat Completions JSON body into a canonical call. func DecodeChatRequest(body []byte, opts DecodeOptions) (*DecodedChat, error) { sel := strings.TrimSpace(opts.RouteSelector) @@ -134,6 +140,7 @@ func DecodeChatRequest(body []byte, opts DecodeOptions) (*DecodedChat, error) { TopP: w.TopP, MaxOutputTokens: w.MaxTokens, ParallelToolCalls: w.ParallelToolCalls, + ReasoningEffort: strings.TrimSpace(w.ReasoningEffort), }, } if len(w.Metadata) > 0 { @@ -150,6 +157,9 @@ func parseMessages(raw []json.RawMessage) ([]lipapi.Message, error) { for i, it := range raw { m, err := parseMessage(it) if err != nil { + if errors.Is(err, errEmptyAssistantMessage) { + continue + } return nil, fmt.Errorf("openailegacy: messages[%d]: %w", i, err) } out = append(out, m) @@ -216,10 +226,14 @@ func parseAssistantParts(content, toolCalls, functionCall json.RawMessage) ([]li if jsonpresence.IsPresentNonNullJSON(content) { cp, err := parseChatContent(content) - if err != nil { + switch { + case err == nil: + contentParts = cp + case errors.Is(err, errEmptyChatContent): + // treat empty content as absent for tool-call-only turns. + default: return nil, fmt.Errorf("openailegacy: assistant content: %w", err) } - contentParts = cp } if jsonpresence.IsPresentNonNullJSON(toolCalls) { if err := frontendlimits.Bytes("tool_calls", len(toolCalls), frontendlimits.MaxRawJSONPayload); err != nil { @@ -257,7 +271,7 @@ func parseAssistantParts(content, toolCalls, functionCall json.RawMessage) ([]li parts = append(parts, lipapi.Part{Kind: lipapi.PartJSON, Content: functionCallPart}) } if len(parts) == 0 { - return nil, errors.New("openailegacy: assistant message requires content, tool_calls, or function_call") + return nil, errEmptyAssistantMessage } return parts, nil } @@ -316,7 +330,7 @@ func parseChatContent(raw json.RawMessage) ([]lipapi.Part, error) { } s = strings.TrimSpace(s) if s == "" { - return nil, errors.New("message content string is empty") + return nil, errEmptyChatContent } return []lipapi.Part{lipapi.TextPart(s)}, nil } diff --git a/internal/plugins/frontends/openailegacy/decode_test.go b/internal/plugins/frontends/openailegacy/decode_test.go index 49511e6a..c057b01e 100644 --- a/internal/plugins/frontends/openailegacy/decode_test.go +++ b/internal/plugins/frontends/openailegacy/decode_test.go @@ -263,6 +263,81 @@ func TestDecodeChat_assistantToolCalls(t *testing.T) { } } +// TestDecodeChat_assistantToolCallsEmptyContentString proves an assistant turn carrying only a +// tool call may use an empty content string (not just null), as some clients (e.g. OpenCode) emit. +func TestDecodeChat_assistantToolCallsEmptyContentString(t *testing.T) { + t.Parallel() + body := []byte(`{ + "model": "gpt-4o-mini", + "messages": [{ + "role": "assistant", + "content": "", + "tool_calls": [{"id":"call_1","type":"function","function":{"name":"x","arguments":"{}"}}] + }] +}`) + d, err := openailegacy.DecodeChatRequest(body, openailegacy.DecodeOptions{RouteSelector: "stub:gpt-4o-mini"}) + if err != nil { + t.Fatalf("decode: %v", err) + } + if len(d.Call.Messages) != 1 || len(d.Call.Messages[0].Parts) != 1 { + t.Fatalf("parts: %#v", d.Call.Messages) + } + if d.Call.Messages[0].Parts[0].Kind != lipapi.PartJSON { + t.Fatalf("want PartJSON, got %#v", d.Call.Messages[0].Parts[0]) + } + if err := d.Call.Validate(); err != nil { + t.Fatal(err) + } +} + +// TestDecodeChat_emptyAssistantMessageIsSkipped proves OpenAI-compatible +// clients may include empty assistant history entries after compaction. +func TestDecodeChat_emptyAssistantMessageIsSkipped(t *testing.T) { + t.Parallel() + body := []byte(`{ + "model": "gpt-4o-mini", + "messages": [ + {"role":"user","content":"before"}, + {"role":"assistant","content":""}, + {"role":"user","content":"after"} + ] +}`) + d, err := openailegacy.DecodeChatRequest(body, openailegacy.DecodeOptions{RouteSelector: "stub:gpt-4o-mini"}) + if err != nil { + t.Fatalf("decode: %v", err) + } + if got := len(d.Call.Messages); got != 2 { + t.Fatalf("messages = %d, want 2", got) + } + for _, m := range d.Call.Messages { + if m.Role == lipapi.RoleAssistant { + t.Fatalf("empty assistant message was not skipped: %#v", d.Call.Messages) + } + } + if err := d.Call.Validate(); err != nil { + t.Fatal(err) + } +} + +// TestDecodeChat_reasoningEffortFromBody proves the Chat Completions body's reasoning_effort +// field is decoded into call.Options.ReasoningEffort before executor routing merges URI +// params. Route URI params override matching body options later in core execution. +func TestDecodeChat_reasoningEffortFromBody(t *testing.T) { + t.Parallel() + body := []byte(`{ + "model": "openai-codex:openai/gpt-5.4-mini?reasoning_effort=xhigh", + "reasoning_effort": "medium", + "messages": [{"role":"user","content":"hi"}] +}`) + d, err := openailegacy.DecodeChatRequest(body, openailegacy.DecodeOptions{RouteSelector: "openai-codex:openai/gpt-5.4-mini?reasoning_effort=xhigh"}) + if err != nil { + t.Fatalf("decode: %v", err) + } + if d.Call.Options.ReasoningEffort != "medium" { + t.Fatalf("ReasoningEffort %q, want %q (body field must be decoded)", d.Call.Options.ReasoningEffort, "medium") + } +} + func TestDecodeChat_assistantFunctionCallLegacy(t *testing.T) { t.Parallel() body := []byte(`{ diff --git a/internal/plugins/frontends/openailegacy/handler.go b/internal/plugins/frontends/openailegacy/handler.go index 9dc9debd..a4d0f762 100644 --- a/internal/plugins/frontends/openailegacy/handler.go +++ b/internal/plugins/frontends/openailegacy/handler.go @@ -5,6 +5,7 @@ import ( "log/slog" "net/http" "strings" + "time" "github.com/matdev83/go-llm-interactive-proxy/internal/core/diag" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/decodeqos" @@ -13,6 +14,7 @@ import ( "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/jsonguard" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/reqbody" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/streamdebug" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk/traffic" @@ -28,12 +30,14 @@ type Handler struct { Exec lipsdk.ExecutorView // DefaultRouteSelector is used when HeaderRouteSelector is absent. DefaultRouteSelector string - MaxRequestBodyBytes int64 - Log *slog.Logger - TrafficPorts traffic.PortBundle - DecodeLimiter *decodeqos.Limiter - PreRequestKeepalive lipsdk.FrontendKeepaliveConfig - Config Config + // RoutePrefixes are backend route-selector prefixes accepted from body model. + RoutePrefixes routeselect.PrefixSet + MaxRequestBodyBytes int64 + Log *slog.Logger + TrafficPorts traffic.PortBundle + DecodeLimiter *decodeqos.Limiter + PreRequestKeepalive lipsdk.FrontendKeepaliveConfig + Config Config } func (h *Handler) maxBodyLimit() int64 { @@ -110,14 +114,17 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if sel == "" { - sel = routeselect.FromModelOrDefault(body, h.DefaultRouteSelector) + sel = h.RoutePrefixes.FromModelOrDefault(body, h.DefaultRouteSelector) } decoded, err := DecodeChatRequest(body, DecodeOptions{RouteSelector: sel, Headers: r.Header}) releaseDecode() if err != nil { - if h.Log != nil { - diag.LogError(ctx, h.Log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) + log := h.Log + if log == nil { + log = slog.Default() } + diag.LogError(ctx, log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) + streamdebug.LogDecodeFailure(ctx, log, ID, body, err) h.logWriteJSONErr(ctx, "write error json failed", WriteErrorJSON(w, http.StatusBadRequest, "invalid request JSON", "invalid_request_error", "")) return } @@ -137,6 +144,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { SessionID: call.Session.CorrelationID(), }, "http", ct, body) + streamdebug.LogCall(ctx, h.Log, ID, call, decoded.Stream, len(body), sel) + executeStart := time.Now() es, err := h.execute(ctx, w, call, decoded.Stream) if err != nil { out := execerr.ClassifyExecute(err) @@ -157,8 +166,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } return } + streamdebug.LogExecuteOpened(ctx, h.Log, ID, call, executeStart) ctx = diag.EnsureCallDiag(ctx, traceID, call.Session.ALegID) + es = streamdebug.Wrap(ctx, h.Log, ID, call, es, executeStart) opts := EncodeOptions{ CompletionID: "chatcmpl_" + diag.StableCallToken(call), diff --git a/internal/plugins/frontends/openailegacy/handler_gzip_test.go b/internal/plugins/frontends/openailegacy/handler_gzip_test.go new file mode 100644 index 00000000..cda431d4 --- /dev/null +++ b/internal/plugins/frontends/openailegacy/handler_gzip_test.go @@ -0,0 +1,44 @@ +package openailegacy_test + +import ( + "bytes" + "compress/gzip" + "net/http" + "net/http/httptest" + "testing" + + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/openailegacy" +) + +// TestHandler_acceptsGzipContentEncoding proves the frontend transparently decompresses +// Content-Encoding: gzip request bodies before preflight/decode, so gzip-compressed clients +// (e.g. some AI coding agents) are not bounced with "invalid request JSON". +func TestHandler_acceptsGzipContentEncoding(t *testing.T) { + t.Parallel() + + plain := readGolden(t, "create_text_nonstream.json") + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write(plain); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + + exec := &recordingExecutor{} + h := &openailegacy.Handler{Exec: exec, DefaultRouteSelector: "stub:gpt-4o-mini"} + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(buf.Bytes())) + req.Header.Set("Content-Encoding", "gzip") + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + + h.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: %d body: %s", rr.Code, rr.Body.String()) + } + if !exec.called { + t.Fatal("executor was not called for gzip-encoded body") + } +} diff --git a/internal/plugins/frontends/openailegacy/integration_test.go b/internal/plugins/frontends/openailegacy/integration_test.go index 8c13e31d..2d2edf68 100644 --- a/internal/plugins/frontends/openailegacy/integration_test.go +++ b/internal/plugins/frontends/openailegacy/integration_test.go @@ -12,6 +12,7 @@ import ( "testing" front "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/openailegacy" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" refcli "github.com/matdev83/go-llm-interactive-proxy/internal/refclient/openaichat" "github.com/matdev83/go-llm-interactive-proxy/internal/refclient/refclienttest" "github.com/matdev83/go-llm-interactive-proxy/internal/testkit" @@ -337,7 +338,7 @@ func TestIntegration_routeHeaderOverridesDefault(t *testing.T) { t.Parallel() var capture sync.Map ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &capture) - h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route"} + h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route", RoutePrefixes: routeselect.NewPrefixSet([]string{"stub"})} mux := http.NewServeMux() mux.Handle("/v1/chat/completions", h) srv := httptest.NewServer(mux) @@ -373,7 +374,7 @@ func TestIntegration_modelRouteSelectorUsedWhenHeaderAbsent(t *testing.T) { t.Parallel() var capture sync.Map ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &capture) - h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route"} + h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route", RoutePrefixes: routeselect.NewPrefixSet([]string{"stub"})} mux := http.NewServeMux() mux.Handle("/v1/chat/completions", h) srv := httptest.NewServer(mux) diff --git a/internal/plugins/frontends/openairesponses/codex_body_routing_test.go b/internal/plugins/frontends/openairesponses/codex_body_routing_test.go new file mode 100644 index 00000000..6c3dd3ae --- /dev/null +++ b/internal/plugins/frontends/openairesponses/codex_body_routing_test.go @@ -0,0 +1,84 @@ +package openairesponses_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/b2bua" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/execbackend" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/hooks" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" + "github.com/matdev83/go-llm-interactive-proxy/internal/core/runtime" + front "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/openairesponses" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" + "github.com/matdev83/go-llm-interactive-proxy/internal/testkit" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +// TestIntegration_openaiCodexBodyModelOverridesRouteWithReasoningEffort proves that a +// client can manually route to the openai-codex backend with an arbitrary model and URI +// params by putting "openai-codex:?reasoning_effort=" in the request body +// model field. The selector must override the configured default route, and the +// reasoning_effort URI param must be converted into call.Options.ReasoningEffort (the +// canonical data structure the openai-codex backend reads to set reasoning effort). +func TestIntegration_openaiCodexBodyModelOverridesRouteWithReasoningEffort(t *testing.T) { + t.Parallel() + st, err := b2bua.NewMemoryStore(b2bua.MemoryStoreOptions{}) + if err != nil { + t.Fatal(err) + } + var captured lipapi.Call + ex := &runtime.Executor{ + Store: st, + Bus: hooks.New(hooks.Config{}), + Rand: routing.NewSeededRng(42), + Backends: map[string]execbackend.Backend{ + "openai-codex": { + Caps: lipapi.NewBackendCaps(lipapi.CapabilityStreaming), + Open: func(_ context.Context, call lipapi.Call, _ routing.AttemptCandidate) (lipapi.ManagedEventStream, error) { + captured = call + return lipapi.NewFixedEventStream([]lipapi.Event{ + {Kind: lipapi.EventResponseStarted}, + {Kind: lipapi.EventMessageStarted}, + {Kind: lipapi.EventTextDelta, Delta: "ok"}, + {Kind: lipapi.EventResponseFinished}, + }), nil + }, + }, + }, + } + testkit.WireConformanceExecutorSecureSession(t, ex) + + h := &front.Handler{Exec: ex, DefaultRouteSelector: "openai-codex:gpt-5.5", RoutePrefixes: routeselect.NewPrefixSet([]string{"openai-codex"})} + mux := http.NewServeMux() + mux.Handle("/v1/responses", h) + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, srv.URL+"/v1/responses", + strings.NewReader(`{"model":"openai-codex:gpt-5.5?reasoning_effort=low","input":"x"}`)) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", "application/json") + res, err := testkit.IntegrationHTTPClient(nil).Do(req) + if err != nil { + t.Fatal(err) + } + defer func() { _ = res.Body.Close() }() + if res.StatusCode != http.StatusOK { + b, _ := io.ReadAll(res.Body) + t.Fatalf("status %d body %s", res.StatusCode, string(b)) + } + + if want := "openai-codex:gpt-5.5?reasoning_effort=low"; captured.Route.Selector != want { + t.Fatalf("route selector %q, want %q", captured.Route.Selector, want) + } + if captured.Options.ReasoningEffort != "low" { + t.Fatalf("call.Options.ReasoningEffort %q, want %q", captured.Options.ReasoningEffort, "low") + } +} diff --git a/internal/plugins/frontends/openairesponses/handler.go b/internal/plugins/frontends/openairesponses/handler.go index e77164a0..f020c734 100644 --- a/internal/plugins/frontends/openairesponses/handler.go +++ b/internal/plugins/frontends/openairesponses/handler.go @@ -8,6 +8,7 @@ import ( "log/slog" "net/http" "strings" + "time" "github.com/matdev83/go-llm-interactive-proxy/internal/core/diag" "github.com/matdev83/go-llm-interactive-proxy/internal/core/securesession/domain" @@ -18,6 +19,7 @@ import ( "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/reqbody" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/sessionwire" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/streamdebug" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipsdk/traffic" @@ -34,6 +36,8 @@ type Handler struct { Exec lipsdk.ExecutorView // DefaultRouteSelector is used when HeaderRouteSelector is absent. DefaultRouteSelector string + // RoutePrefixes are backend route-selector prefixes accepted from body model. + RoutePrefixes routeselect.PrefixSet // MaxRequestBodyBytes caps the request body; zero uses reqbody.DefaultMaxBytes. MaxRequestBodyBytes int64 Log *slog.Logger @@ -154,14 +158,17 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } if sel == "" { - sel = routeselect.FromModelOrDefault(body, h.DefaultRouteSelector) + sel = h.RoutePrefixes.FromModelOrDefault(body, h.DefaultRouteSelector) } decoded, err := DecodeCreateRequest(body, DecodeOptions{RouteSelector: sel, Headers: r.Header}) releaseDecode() if err != nil { - if h.Log != nil { - diag.LogError(ctx, h.Log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) + log := h.Log + if log == nil { + log = slog.Default() } + diag.LogError(ctx, log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) + streamdebug.LogDecodeFailure(ctx, log, ID, body, err) h.logWriteJSONErr(ctx, "write error json failed", WriteErrorJSON( w, http.StatusBadRequest, @@ -193,6 +200,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { SessionID: call.Session.CorrelationID(), }, "http", ct, body) + streamdebug.LogCall(ctx, h.Log, ID, call, decoded.Stream, len(body), sel) + executeStart := time.Now() es, err := h.execute(ctx, w, call, decoded.Stream) if err != nil { out := execerr.ClassifyExecute(err) @@ -231,8 +240,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } return } + streamdebug.LogExecuteOpened(ctx, h.Log, ID, call, executeStart) ctx = diag.EnsureCallDiag(ctx, traceID, call.Session.ALegID) + es = streamdebug.Wrap(ctx, h.Log, ID, call, es, executeStart) opts := EncodeOptions{ ResponseID: responseIDForCall(call), diff --git a/internal/plugins/frontends/openairesponses/integration_test.go b/internal/plugins/frontends/openairesponses/integration_test.go index 56bfacaf..343ba45c 100644 --- a/internal/plugins/frontends/openairesponses/integration_test.go +++ b/internal/plugins/frontends/openairesponses/integration_test.go @@ -12,6 +12,7 @@ import ( "testing" front "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/openairesponses" + "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/routeselect" refcli "github.com/matdev83/go-llm-interactive-proxy/internal/refclient/openairesponses" "github.com/matdev83/go-llm-interactive-proxy/internal/refclient/refclienttest" "github.com/matdev83/go-llm-interactive-proxy/internal/testkit" @@ -365,7 +366,7 @@ func TestIntegration_routeHeaderOverridesDefault(t *testing.T) { t.Parallel() var capture sync.Map ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &capture) - h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route"} + h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route", RoutePrefixes: routeselect.NewPrefixSet([]string{"stub"})} mux := http.NewServeMux() mux.Handle("/v1/responses", h) srv := httptest.NewServer(mux) @@ -400,7 +401,7 @@ func TestIntegration_modelRouteSelectorUsedWhenHeaderAbsent(t *testing.T) { t.Parallel() var capture sync.Map ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &capture) - h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route"} + h := &front.Handler{Exec: ex, DefaultRouteSelector: "stub:default-route", RoutePrefixes: routeselect.NewPrefixSet([]string{"stub"})} mux := http.NewServeMux() mux.Handle("/v1/responses", h) srv := httptest.NewServer(mux) diff --git a/internal/plugins/frontends/reqbody/body.go b/internal/plugins/frontends/reqbody/body.go index 3e529909..5c7f7500 100644 --- a/internal/plugins/frontends/reqbody/body.go +++ b/internal/plugins/frontends/reqbody/body.go @@ -2,10 +2,12 @@ package reqbody import ( + "compress/gzip" "errors" "fmt" "io" "net/http" + "strings" ) // DefaultMaxBytes is the maximum request body size when no explicit limit is set. @@ -13,14 +15,42 @@ const DefaultMaxBytes int64 = 8 << 20 // ReadAll reads r.Body using http.MaxBytesReader. On limit exceeded it returns a non-nil err // for which TooLarge returns true; callers should respond with HTTP 413 without treating it as JSON parse failure. +// +// When the request advertises Content-Encoding: gzip, the body is transparently decompressed +// and the byte limit is applied to the decompressed size (mitigating decompression bombs). func ReadAll(w http.ResponseWriter, r *http.Request, maxBytes int64) (data []byte, err error) { if maxBytes <= 0 { maxBytes = DefaultMaxBytes } - lr := http.MaxBytesReader(w, r.Body, maxBytes) + src := r.Body + var gzr *gzip.Reader + if isGzipEncoded(r) { + gzr, err = gzip.NewReader(src) + if err != nil { + if cerr := src.Close(); cerr != nil { + err = errors.Join(err, fmt.Errorf("reqbody: close body reader: %w", cerr)) + } + return nil, err + } + } + reader := src + if gzr != nil { + reader = gzr + } + lr := http.MaxBytesReader(w, reader, maxBytes) defer func() { + var cerrs []error if cerr := lr.Close(); cerr != nil { - closeErr := fmt.Errorf("reqbody: close body reader: %w", cerr) + cerrs = append(cerrs, fmt.Errorf("reqbody: close body reader: %w", cerr)) + } + // gzip.Reader.Close does not close the underlying body, so close src explicitly. + if gzr != nil { + if cerr := src.Close(); cerr != nil { + cerrs = append(cerrs, fmt.Errorf("reqbody: close gzip source body: %w", cerr)) + } + } + if len(cerrs) > 0 { + closeErr := errors.Join(cerrs...) if err != nil { err = errors.Join(err, closeErr) } else { @@ -32,6 +62,19 @@ func ReadAll(w http.ResponseWriter, r *http.Request, maxBytes int64) (data []byt return data, err } +func isGzipEncoded(r *http.Request) bool { + h := strings.TrimSpace(r.Header.Get("Content-Encoding")) + if h == "" { + return false + } + for part := range strings.SplitSeq(h, ",") { + if strings.EqualFold(strings.TrimSpace(part), "gzip") { + return true + } + } + return false +} + // TooLarge reports whether err is from exceeding MaxBytesReader's limit. // It uses errors.As so any error in the chain that unwraps to *http.MaxBytesError matches. func TooLarge(err error) bool { diff --git a/internal/plugins/frontends/reqbody/body_test.go b/internal/plugins/frontends/reqbody/body_test.go index 555afc80..44667a92 100644 --- a/internal/plugins/frontends/reqbody/body_test.go +++ b/internal/plugins/frontends/reqbody/body_test.go @@ -2,11 +2,13 @@ package reqbody_test import ( "bytes" + "compress/gzip" "errors" "fmt" "io" "net/http" "net/http/httptest" + "strings" "testing" "github.com/matdev83/go-llm-interactive-proxy/internal/plugins/frontends/reqbody" @@ -103,3 +105,67 @@ func TestReadAll_readErrorJoinedWithCloseError(t *testing.T) { t.Fatalf("expected close error joined in chain, got %v", err) } } + +func gzipBytes(t *testing.T, payload []byte) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write(payload); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + return buf.Bytes() +} + +func TestReadAll_decompressesGzipContentEncoding(t *testing.T) { + t.Parallel() + payload := []byte(`{"model":"openai-codex:gpt-5.4-mini","messages":[{"role":"user","content":"hi"}]}`) + gz := gzipBytes(t, payload) + r := httptest.NewRequest("POST", "/", bytes.NewReader(gz)) + r.Header.Set("Content-Encoding", "gzip") + r.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + b, err := reqbody.ReadAll(w, r, int64(len(payload)+64)) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !bytes.Equal(b, payload) { + t.Fatalf("expected decompressed JSON payload, got %q", string(b)) + } +} + +func TestReadAll_gzipTooLargeAfterDecompression(t *testing.T) { + t.Parallel() + payload := bytes.Repeat([]byte("a"), 1000) + gz := gzipBytes(t, payload) + r := httptest.NewRequest("POST", "/", bytes.NewReader(gz)) + r.Header.Set("Content-Encoding", "gzip") + w := httptest.NewRecorder() + _, err := reqbody.ReadAll(w, r, 100) + if err == nil || !reqbody.TooLarge(err) { + t.Fatalf("expected too-large error after decompression, got %v", err) + } +} + +func TestReadAll_gzipSourceCloseErrorIsDistinguishable(t *testing.T) { + t.Parallel() + closeErr := errors.New("source close failed") + payload := []byte(`{"x":1}`) + body := &errCloseBody{r: bytes.NewReader(gzipBytes(t, payload)), err: closeErr} + r := httptest.NewRequest("POST", "/", body) + r.Header.Set("Content-Encoding", "gzip") + w := httptest.NewRecorder() + + _, err := reqbody.ReadAll(w, r, int64(len(payload)+64)) + if err == nil { + t.Fatal("expected gzip source close error") + } + if !errors.Is(err, closeErr) { + t.Fatalf("expected close error in chain, got %v", err) + } + if !strings.Contains(err.Error(), "reqbody: close gzip source body") { + t.Fatalf("close error label = %q, want gzip source body", err.Error()) + } +} diff --git a/internal/plugins/frontends/routeselect/routeselect.go b/internal/plugins/frontends/routeselect/routeselect.go index 370b59b4..30fdf96d 100644 --- a/internal/plugins/frontends/routeselect/routeselect.go +++ b/internal/plugins/frontends/routeselect/routeselect.go @@ -5,32 +5,28 @@ package routeselect import ( "encoding/json" "strings" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/routing" ) -var inlineRoutePrefixes = map[string]struct{}{ - "acp": {}, - "anthropic": {}, - "bedrock": {}, - "gemini": {}, - "llamacpp": {}, - "local-stub": {}, - "nvidia": {}, - "ollama": {}, - "ollama-cloud": {}, - "openai-legacy": {}, - "openai-responses": {}, - "openrouter": {}, - // Test frontends use stub route selectors with executor stubs outside the production backend registry. - "stub": {}, +type PrefixSet map[string]struct{} + +func NewPrefixSet(prefixes []string) PrefixSet { + filtered := routing.FilterRoutePrefixes(prefixes) + out := make(PrefixSet, len(filtered)) + for _, prefix := range filtered { + out[prefix] = struct{}{} + } + return out } // InlineOrDefault returns model when it has a known backend prefix before the colon delimiter. // Otherwise it returns defaultRoute with surrounding whitespace removed. -func InlineOrDefault(model, defaultRoute string) string { +func (p PrefixSet) InlineOrDefault(model, defaultRoute string) string { model = strings.TrimSpace(model) prefix, _, ok := strings.Cut(model, ":") if ok { - if _, known := inlineRoutePrefixes[strings.TrimSpace(prefix)]; known { + if _, known := p[strings.TrimSpace(prefix)]; known { return model } } @@ -39,12 +35,12 @@ func InlineOrDefault(model, defaultRoute string) string { // FromModelOrDefault parses body for a model field and returns it when it carries a known inline route prefix. // If decoding fails or the model has no known prefix, it returns defaultRoute with surrounding whitespace removed. -func FromModelOrDefault(body []byte, defaultRoute string) string { +func (p PrefixSet) FromModelOrDefault(body []byte, defaultRoute string) string { var req struct { Model string `json:"model"` } if err := json.Unmarshal(body, &req); err == nil { - return InlineOrDefault(req.Model, defaultRoute) + return p.InlineOrDefault(req.Model, defaultRoute) } return strings.TrimSpace(defaultRoute) } diff --git a/internal/plugins/frontends/routeselect/routeselect_test.go b/internal/plugins/frontends/routeselect/routeselect_test.go index 8e9c3517..dc048b01 100644 --- a/internal/plugins/frontends/routeselect/routeselect_test.go +++ b/internal/plugins/frontends/routeselect/routeselect_test.go @@ -6,6 +6,18 @@ import ( func TestInlineOrDefault(t *testing.T) { t.Parallel() + prefixes := NewPrefixSet([]string{ + "anthropic", + "huggingface", + "llamacpp", + "lmstudio", + "ollama-cloud", + "openai-codex", + "opencode-go", + "opencode-zen", + "stub", + "vllm", + }) tests := []struct { name string @@ -61,13 +73,69 @@ func TestInlineOrDefault(t *testing.T) { defaultRoute: " openai-legacy:gpt-4o ", want: "anthropic:claude-3-5-sonnet", }, + // openai-codex: arbitrary model + optional URI params must override the default route. + { + name: "openai-codex prefix with reasoning_effort param", + model: "openai-codex:gpt-5.5?reasoning_effort=low", + defaultRoute: "openai-codex:gpt-5.5", + want: "openai-codex:gpt-5.5?reasoning_effort=low", + }, + { + name: "openai-codex prefix gpt-5.4", + model: "openai-codex:gpt-5.4", + defaultRoute: "openai-codex:gpt-5.5", + want: "openai-codex:gpt-5.4", + }, + { + name: "openai-codex prefix gpt-5.4-mini", + model: "openai-codex:gpt-5.4-mini", + defaultRoute: "openai-codex:gpt-5.5", + want: "openai-codex:gpt-5.4-mini", + }, + { + name: "openai-codex prefix arbitrary model not in static inventory", + model: "openai-codex:gpt-5.3-codex-spark", + defaultRoute: "openai-codex:gpt-5.5", + want: "openai-codex:gpt-5.3-codex-spark", + }, + // Other standard backends missing from the allowlist must also route from the body model. + { + name: "opencode-go prefix", + model: "opencode-go:zen-go-1", + defaultRoute: "openai-codex:gpt-5.5", + want: "opencode-go:zen-go-1", + }, + { + name: "opencode-zen prefix", + model: "opencode-zen:zen-1", + defaultRoute: "openai-codex:gpt-5.5", + want: "opencode-zen:zen-1", + }, + { + name: "huggingface prefix", + model: "huggingface:Qwen/Qwen3", + defaultRoute: "openai-codex:gpt-5.5", + want: "huggingface:Qwen/Qwen3", + }, + { + name: "vllm prefix", + model: "vllm:llama-3", + defaultRoute: "openai-codex:gpt-5.5", + want: "vllm:llama-3", + }, + { + name: "lmstudio prefix", + model: "lmstudio:local-model", + defaultRoute: "openai-codex:gpt-5.5", + want: "lmstudio:local-model", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := InlineOrDefault(tt.model, tt.defaultRoute) + got := prefixes.InlineOrDefault(tt.model, tt.defaultRoute) if got != tt.want { t.Errorf("InlineOrDefault(%q, %q) = %q, want %q", tt.model, tt.defaultRoute, got, tt.want) } @@ -77,6 +145,7 @@ func TestInlineOrDefault(t *testing.T) { func TestFromModelOrDefault(t *testing.T) { t.Parallel() + prefixes := NewPrefixSet([]string{"anthropic", "openai-codex"}) tests := []struct { name string @@ -114,13 +183,26 @@ func TestFromModelOrDefault(t *testing.T) { defaultRoute: "openai-responses:gpt-4o", want: "openai-responses:gpt-4o", }, + // openai-codex: the body model (with optional URI params) must override the default route. + { + name: "openai-codex inline model with reasoning_effort param", + body: `{"model": "openai-codex:gpt-5.5?reasoning_effort=low"}`, + defaultRoute: "openai-codex:gpt-5.5", + want: "openai-codex:gpt-5.5?reasoning_effort=low", + }, + { + name: "openai-codex inline model no params", + body: `{"model": "openai-codex:gpt-5.4"}`, + defaultRoute: "openai-codex:gpt-5.5", + want: "openai-codex:gpt-5.4", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got := FromModelOrDefault([]byte(tt.body), tt.defaultRoute) + got := prefixes.FromModelOrDefault([]byte(tt.body), tt.defaultRoute) if got != tt.want { t.Errorf("FromModelOrDefault(%q, %q) = %q, want %q", tt.body, tt.defaultRoute, got, tt.want) } diff --git a/internal/plugins/frontends/streamdebug/streamdebug.go b/internal/plugins/frontends/streamdebug/streamdebug.go new file mode 100644 index 00000000..6855b5a2 --- /dev/null +++ b/internal/plugins/frontends/streamdebug/streamdebug.go @@ -0,0 +1,335 @@ +package streamdebug + +import ( + "context" + "encoding/json" + "errors" + "io" + "log/slog" + "sort" + "strings" + "sync" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/diag" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +// Enabled reports whether verbose local turn diagnostics are enabled. +func Enabled() bool { + return diag.DebugTurnsEnabled() +} + +// LogCall records a compact canonical request shape without logging prompt text. +func LogCall(ctx context.Context, log *slog.Logger, frontend string, call *lipapi.Call, stream bool, bodyBytes int, selector string) { + if !Enabled() || call == nil { + return + } + s := summarizeCall(call) + diag.LoggerOrDefault(log).DebugContext(ctx, "lip.debug.frontend_call", + "frontend", frontend, + "call_id", call.ID, + "trace_id", diag.StableCallID(call), + "a_leg_id", strings.TrimSpace(call.Session.ALegID), + "operation", string(call.Invocation.Operation), + "route_selector", selector, + "stream", stream, + "body_bytes", bodyBytes, + "messages", len(call.Messages), + "instructions", len(call.Instructions), + "tools", len(call.Tools), + "role_counts", strings.Join(s.roleCounts, ","), + "part_counts", strings.Join(s.partCounts, ","), + "tool_result_ids", strings.Join(s.toolResultIDs, ","), + "assistant_tool_call_ids", strings.Join(s.assistantToolCallIDs, ","), + "reasoning_effort", call.Options.ReasoningEffort, + "has_max_output_tokens", call.Options.MaxOutputTokens != nil, + "has_temperature", call.Options.Temperature != nil, + "has_top_p", call.Options.TopP != nil, + ) +} + +// LogDecodeFailure records why a frontend rejected a request before canonical call creation. +func LogDecodeFailure(ctx context.Context, log *slog.Logger, frontend string, body []byte, err error) { + if !Enabled() { + return + } + summary := summarizeBody(body) + diag.LoggerOrDefault(log).DebugContext(ctx, "lip.debug.frontend_decode_failed", + "frontend", frontend, + "body_bytes", len(body), + "json_valid", summary.valid, + "top_keys", strings.Join(summary.keys, ","), + "model", summary.model, + "messages", summary.messages, + "input_items", summary.inputItems, + "tools_present", summary.toolsPresent, + "error", errString(err), + ) +} + +// LogExecuteOpened records time spent before the executor returns an event stream. +func LogExecuteOpened(ctx context.Context, log *slog.Logger, frontend string, call *lipapi.Call, start time.Time) { + if !Enabled() || call == nil { + return + } + diag.LoggerOrDefault(log).DebugContext(ctx, "lip.debug.frontend_execute_opened", + "frontend", frontend, + "call_id", call.ID, + "trace_id", diag.StableCallID(call), + "a_leg_id", strings.TrimSpace(call.Session.ALegID), + "duration_ms", time.Since(start).Milliseconds(), + ) +} + +// Wrap logs stream progress and terminal state while preserving EventStream semantics. +func Wrap(_ context.Context, log *slog.Logger, frontend string, call *lipapi.Call, es lipapi.EventStream, start time.Time) lipapi.EventStream { + if !Enabled() || es == nil || call == nil { + return es + } + return &stream{ + log: diag.LoggerOrDefault(log), + frontend: frontend, + call: call, + inner: es, + start: start, + } +} + +var _ lipapi.EventStream = (*stream)(nil) + +type stream struct { + mu sync.Mutex + log *slog.Logger + frontend string + call *lipapi.Call + inner lipapi.EventStream + start time.Time + count int + kindCounts map[string]int + firstTextMs int64 + firstReasonMs int64 + firstLogged bool + terminalLogged bool +} + +func (s *stream) Recv(ctx context.Context) (lipapi.Event, error) { + ev, err := s.inner.Recv(ctx) + if err != nil { + s.logTerminal(ctx, err) + return ev, err + } + s.mu.Lock() + defer s.mu.Unlock() + s.count++ + if s.kindCounts == nil { + s.kindCounts = map[string]int{} + } + s.kindCounts[string(ev.Kind)]++ + if ev.Kind == lipapi.EventTextDelta && s.firstTextMs == 0 { + s.firstTextMs = time.Since(s.start).Milliseconds() + s.logContentFirst(ctx, ev.Kind, ev.Delta, "") + } + if ev.Kind == lipapi.EventReasoningDelta && s.firstReasonMs == 0 { + s.firstReasonMs = time.Since(s.start).Milliseconds() + s.logContentFirst(ctx, ev.Kind, ev.Delta, "") + } + if !s.firstLogged { + s.firstLogged = true + s.log.DebugContext(ctx, "lip.debug.stream_first_event", + "frontend", s.frontend, + "call_id", s.call.ID, + "trace_id", diag.StableCallID(s.call), + "a_leg_id", strings.TrimSpace(s.call.Session.ALegID), + "event_kind", string(ev.Kind), + "duration_ms", time.Since(s.start).Milliseconds(), + ) + } + if shouldLogEvent(ev.Kind) { + s.log.DebugContext(ctx, "lip.debug.stream_event", + "frontend", s.frontend, + "call_id", s.call.ID, + "trace_id", diag.StableCallID(s.call), + "event_index", s.count, + "event_kind", string(ev.Kind), + "tool_call_id", ev.ToolCallID, + "tool_name", ev.ToolName, + "finish_reason", ev.FinishReason, + ) + } + return ev, nil +} + +func (s *stream) Close() error { + err := s.inner.Close() + s.logTerminal(context.Background(), err) + return err +} + +func (s *stream) logTerminal(ctx context.Context, err error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.terminalLogged { + return + } + s.terminalLogged = true + status := "closed" + switch { + case errors.Is(err, io.EOF): + status = "eof" + case err != nil: + status = "error" + } + s.log.DebugContext(ctx, "lip.debug.stream_terminal", + "frontend", s.frontend, + "call_id", s.call.ID, + "trace_id", diag.StableCallID(s.call), + "a_leg_id", strings.TrimSpace(s.call.Session.ALegID), + "status", status, + "error", errString(err), + "events", s.count, + "event_counts", strings.Join(diag.StableCounts(s.kindCounts), ","), + "first_text_ms", s.firstTextMs, + "first_reasoning_ms", s.firstReasonMs, + "duration_ms", time.Since(s.start).Milliseconds(), + ) +} + +func (s *stream) logContentFirst(ctx context.Context, kind lipapi.EventKind, delta, detail string) { + s.log.DebugContext(ctx, "lip.debug.stream_first_content_event", + "frontend", s.frontend, + "call_id", s.call.ID, + "trace_id", diag.StableCallID(s.call), + "a_leg_id", strings.TrimSpace(s.call.Session.ALegID), + "event_kind", string(kind), + "delta_bytes", len(delta), + "detail", detail, + "duration_ms", time.Since(s.start).Milliseconds(), + ) +} + +func shouldLogEvent(kind lipapi.EventKind) bool { + switch kind { + case lipapi.EventToolCallStarted, lipapi.EventToolCallFinished, lipapi.EventError, lipapi.EventResponseFinished: + return true + default: + return false + } +} + +func errString(err error) string { + if err == nil || errors.Is(err, io.EOF) { + return "" + } + return err.Error() +} + +type callSummary struct { + roleCounts []string + partCounts []string + toolResultIDs []string + assistantToolCallIDs []string +} + +type bodySummary struct { + valid bool + keys []string + model string + messages int + inputItems int + toolsPresent bool +} + +func summarizeBody(body []byte) bodySummary { + out := bodySummary{valid: json.Valid(body)} + var top map[string]json.RawMessage + if json.Unmarshal(body, &top) != nil { + return out + } + keys := make([]string, 0, len(top)) + for key := range top { + keys = append(keys, key) + } + out.keys = stableStrings(keys) + if raw := top["model"]; len(raw) > 0 { + _ = json.Unmarshal(raw, &out.model) + } + if raw := top["messages"]; len(raw) > 0 { + var msgs []json.RawMessage + if json.Unmarshal(raw, &msgs) == nil { + out.messages = len(msgs) + } + } + if raw := top["input"]; len(raw) > 0 { + var items []json.RawMessage + if json.Unmarshal(raw, &items) == nil { + out.inputItems = len(items) + } else { + var text string + if json.Unmarshal(raw, &text) == nil && text != "" { + out.inputItems = 1 + } + } + } + _, out.toolsPresent = top["tools"] + return out +} + +func summarizeCall(call *lipapi.Call) callSummary { + roleCounts := map[string]int{} + partCounts := map[string]int{} + var toolResultIDs []string + var assistantToolCallIDs []string + for _, msg := range call.Messages { + roleCounts[string(msg.Role)]++ + for _, part := range msg.Parts { + partCounts[string(part.Kind)]++ + switch { + case part.Kind == lipapi.PartToolResult: + toolResultIDs = diag.AppendLimited(toolResultIDs, part.ToolCallID, 12) + case msg.Role == lipapi.RoleAssistant && part.Kind == lipapi.PartJSON: + assistantToolCallIDs = diag.AppendLimited(assistantToolCallIDs, assistantJSONCallID(part.Content), 12) + } + } + } + return callSummary{ + roleCounts: diag.StableCounts(roleCounts), + partCounts: diag.StableCounts(partCounts), + toolResultIDs: toolResultIDs, + assistantToolCallIDs: assistantToolCallIDs, + } +} + +func stableStrings(values []string) []string { + out := append([]string(nil), values...) + sort.Strings(out) + return out +} + +func assistantJSONCallID(raw []byte) string { + const maxProbe = 4096 + if len(raw) > maxProbe { + raw = raw[:maxProbe] + } + body := string(raw) + for _, key := range []string{`"call_id"`, `"id"`} { + _, rest, ok := strings.Cut(body, key) + if !ok { + continue + } + colon := strings.IndexByte(rest, ':') + if colon < 0 { + continue + } + rest = strings.TrimSpace(rest[colon+1:]) + if !strings.HasPrefix(rest, `"`) { + continue + } + rest = rest[1:] + end := strings.IndexByte(rest, '"') + if end > 0 { + return rest[:end] + } + } + return "" +} diff --git a/internal/plugins/frontends/streamdebug/streamdebug_test.go b/internal/plugins/frontends/streamdebug/streamdebug_test.go new file mode 100644 index 00000000..eb1216e7 --- /dev/null +++ b/internal/plugins/frontends/streamdebug/streamdebug_test.go @@ -0,0 +1,44 @@ +package streamdebug + +import ( + "context" + "io" + "log/slog" + "testing" + "time" + + "github.com/matdev83/go-llm-interactive-proxy/internal/core/diag" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func TestEnabledUsesDiagGate(t *testing.T) { + t.Parallel() + if Enabled() != diag.DebugTurnsEnabled() { + t.Fatalf("Enabled() = %v, want diag gate %v", Enabled(), diag.DebugTurnsEnabled()) + } +} + +func TestWrapFollowsDebugGate(t *testing.T) { + t.Parallel() + inner := &testStream{} + wrapped := Wrap(context.Background(), slog.New(slog.NewTextHandler(io.Discard, nil)), "test", &lipapi.Call{ID: "call"}, inner, time.Now()) + if diag.DebugTurnsEnabled() { + if wrapped == inner { + t.Fatal("Wrap enabled returned original stream, want debug wrapper") + } + return + } + if wrapped != inner { + t.Fatalf("Wrap disabled returned %T, want original stream", wrapped) + } +} + +type testStream struct{} + +func (*testStream) Recv(context.Context) (lipapi.Event, error) { + return lipapi.Event{}, io.EOF +} + +func (*testStream) Close() error { + return nil +} diff --git a/internal/refbackend/openaicodex/server.go b/internal/refbackend/openaicodex/server.go index 7c12f1aa..560b611c 100644 --- a/internal/refbackend/openaicodex/server.go +++ b/internal/refbackend/openaicodex/server.go @@ -8,10 +8,37 @@ import ( "net/http" "strings" "sync" + + "github.com/gorilla/websocket" ) const maxBodyBytes = 10 << 20 +var wsUpgrader = websocket.Upgrader{ + CheckOrigin: func(*http.Request) bool { return true }, +} + +// WSFailureMode selects a one-shot WebSocket failure behavior for the emulator. +type WSFailureMode int + +const ( + WSFailureNone WSFailureMode = iota + // WSFailurePolicyCloseBeforeEvent closes immediately after upgrade before any event frame. + WSFailurePolicyCloseBeforeEvent + // WSFailureNormalCloseBeforeEvent sends CloseNormalClosure immediately after upgrade. + WSFailureNormalCloseBeforeEvent + // WSFailureNoCanonicalFirstFrame sends one mappable but non-canonical event, then closes normally. + WSFailureNoCanonicalFirstFrame + // WSFailureMalformedFirstFrame sends invalid JSON as the first event frame. + WSFailureMalformedFirstFrame + // WSFailureStall consumes response.create and never sends an event frame. + WSFailureStall + // WSFailureAfterFirstEvent sends the first canonical event frame, then drops the connection. + WSFailureAfterFirstEvent + // WSFailureStallAfterFirstEvent sends the first canonical event frame, then stalls. + WSFailureStallAfterFirstEvent +) + // Config tunes the Codex Responses emulator. type Config struct { Token string @@ -21,11 +48,18 @@ type Config struct { ForcedHTTPStatus int ForcedRetryAfter string ForcedErrorJSON string + // ForcedWSFailure applies one one-shot WebSocket failure mode to the first upgrade. + ForcedWSFailure WSFailureMode + // ForcedWSRejectModel, when set, makes the WebSocket handler send a pre-content + // error event frame and close when the response.create payload names this model. + // Used to exercise reactive gpt-5.5 downgrade on the WebSocket path. + ForcedWSRejectModel string } // CapturedRequest is a snapshot of the latest handled POST /responses request. type CapturedRequest struct { Path string + Transport string Authorization string OpenAIBeta string Originator string @@ -45,6 +79,7 @@ type Server struct { nextForcedStatus int nextForcedRetry string nextForcedErrorJSON string + nextWSFailure WSFailureMode } // New returns an emulator configured from cfg. @@ -55,9 +90,17 @@ func New(cfg Config) *Server { s.nextForcedRetry = cfg.ForcedRetryAfter s.nextForcedErrorJSON = cfg.ForcedErrorJSON } + s.nextWSFailure = cfg.ForcedWSFailure return s } +// ForceNextWSFailure applies mode to the next WebSocket upgrade. +func (s *Server) ForceNextWSFailure(mode WSFailureMode) { + s.mu.Lock() + defer s.mu.Unlock() + s.nextWSFailure = mode +} + // Handler returns the emulator HTTP handler. func (s *Server) Handler() http.Handler { return http.HandlerFunc(s.serve) @@ -75,6 +118,10 @@ func (s *Server) LatestRequest() CapturedRequest { } func (s *Server) serve(w http.ResponseWriter, r *http.Request) { + if websocket.IsWebSocketUpgrade(r) { + s.serveWebSocket(w, r) + return + } if r.Method != http.MethodPost || !strings.HasSuffix(r.URL.Path, "/responses") { http.NotFound(w, r) return @@ -111,17 +158,7 @@ func (s *Server) serve(w http.ResponseWriter, r *http.Request) { } s.mu.Lock() - s.latest = CapturedRequest{ - Path: r.URL.Path, - Authorization: r.Header.Get("Authorization"), - OpenAIBeta: r.Header.Get("OpenAI-Beta"), - Originator: r.Header.Get("originator"), - CodexTaskType: r.Header.Get("Codex-Task-Type"), - ConversationID: r.Header.Get("conversation_id"), - SessionID: r.Header.Get("session_id"), - ChatGPTAccountID: r.Header.Get("chatgpt-account-id"), - Body: maps.Clone(payload), - } + s.latest = captureRequest(r, "https", payload) forcedStatus := s.nextForcedStatus forcedRetry := s.nextForcedRetry forcedJSON := s.nextForcedErrorJSON @@ -147,6 +184,108 @@ func (s *Server) serve(w http.ResponseWriter, r *http.Request) { writeStream(w, s.outputText()) } +func (s *Server) serveWebSocket(w http.ResponseWriter, r *http.Request) { + if s.cfg.Token != "" { + if r.Header.Get("Authorization") != "Bearer "+s.cfg.Token { + http.Error(w, "missing or invalid bearer", http.StatusUnauthorized) + return + } + } + if err := validateCodexHeaders(r); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + respHeader := http.Header{} + if s.cfg.PlanType != "" { + respHeader.Set("x-codex-plan-type", s.cfg.PlanType) + } + if s.cfg.UsagePercent != "" { + respHeader.Set("x-codex-primary-used-percent", s.cfg.UsagePercent) + } + conn, err := wsUpgrader.Upgrade(w, r, respHeader) + if err != nil { + return + } + defer func() { _ = conn.Close() }() + + s.mu.Lock() + forcedWSFailure := s.nextWSFailure + s.nextWSFailure = WSFailureNone + s.mu.Unlock() + switch forcedWSFailure { + case WSFailurePolicyCloseBeforeEvent: + _ = conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "forced ws fail")) + return + case WSFailureNormalCloseBeforeEvent: + _ = conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return + } + + _, frame, err := conn.ReadMessage() + if err != nil { + return + } + var payload map[string]any + if len(frame) > 0 { + _ = json.Unmarshal(frame, &payload) + } + if payload == nil { + payload = map[string]any{} + } + if rejectModel := strings.TrimSpace(s.cfg.ForcedWSRejectModel); rejectModel != "" { + if model, _ := payload["model"].(string); model == rejectModel { + _ = conn.WriteMessage(websocket.TextMessage, []byte( + `{"type":"error","error":{"message":"gpt-5.5 is not available on free plan"}}`)) + return + } + } + s.mu.Lock() + s.latest = captureRequest(r, "websocket", payload) + s.mu.Unlock() + + switch forcedWSFailure { + case WSFailureStall: + // Upgrade succeeded and the request frame was consumed, but no event is + // ever sent. Block until the client abandons the read and closes, which + // is how a server that upgrades but never produces a first event looks. + _, _, _ = conn.ReadMessage() + return + case WSFailureNoCanonicalFirstFrame: + _ = conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.unknown_ack"}`)) + _ = conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + return + case WSFailureMalformedFirstFrame: + _ = conn.WriteMessage(websocket.TextMessage, []byte("{not json")) + return + case WSFailureAfterFirstEvent: + // Send only the first canonical event, then drop mid-stream. + _ = conn.WriteMessage(websocket.TextMessage, []byte(codexEventFrames(s.outputText())[0])) + return + case WSFailureStallAfterFirstEvent: + _ = conn.WriteMessage(websocket.TextMessage, []byte(codexEventFrames(s.outputText())[0])) + _, _, _ = conn.ReadMessage() + return + } + + for _, raw := range codexEventFrames(s.outputText()) { + if err := conn.WriteMessage(websocket.TextMessage, []byte(raw)); err != nil { + return + } + } + _ = conn.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) +} + +func codexEventFrames(text string) []string { + created := `{"type":"response.created","sequence_number":0,"response":{"id":"resp_codex_ref","object":"response","created_at":1715620000,"status":"in_progress","model":"gpt-5.3-codex-spark"}}` + delta := fmt.Sprintf(`{"type":"response.output_text.delta","sequence_number":1,"item_id":"msg_codex_ref","output_index":0,"content_index":0,"delta":%q}`, text) + completed := fmt.Sprintf(`{"type":"response.completed","sequence_number":2,"response":{"id":"resp_codex_ref","object":"response","created_at":1715620000,"status":"completed","model":"gpt-5.3-codex-spark","output":[{"type":"message","id":"msg_codex_ref","status":"completed","role":"assistant","content":[{"type":"output_text","text":%q}]}],"usage":{"input_tokens":3,"output_tokens":5,"total_tokens":8}}}`, text) + return []string{created, delta, completed} +} + func validateCodexHeaders(r *http.Request) error { required := []struct { name string @@ -166,6 +305,21 @@ func validateCodexHeaders(r *http.Request) error { return nil } +func captureRequest(r *http.Request, transport string, body map[string]any) CapturedRequest { + return CapturedRequest{ + Path: r.URL.Path, + Transport: transport, + Authorization: r.Header.Get("Authorization"), + OpenAIBeta: r.Header.Get("OpenAI-Beta"), + Originator: r.Header.Get("originator"), + CodexTaskType: r.Header.Get("Codex-Task-Type"), + ConversationID: r.Header.Get("conversation_id"), + SessionID: r.Header.Get("session_id"), + ChatGPTAccountID: r.Header.Get("chatgpt-account-id"), + Body: maps.Clone(body), + } +} + func (s *Server) outputText() string { if s.cfg.OutputText != "" { return s.cfg.OutputText @@ -197,17 +351,26 @@ func defaultForcedErrorJSON(status int) string { } func writeStream(w http.ResponseWriter, text string) { - created := `{"type":"response.created","sequence_number":0,"response":{"id":"resp_codex_ref","object":"response","created_at":1715620000,"status":"in_progress","model":"gpt-5.3-codex"}}` - delta := fmt.Sprintf(`{"type":"response.output_text.delta","sequence_number":1,"item_id":"msg_codex_ref","output_index":0,"content_index":0,"delta":%q}`, text) - completed := fmt.Sprintf(`{"type":"response.completed","sequence_number":2,"response":{"id":"resp_codex_ref","object":"response","created_at":1715620000,"status":"completed","model":"gpt-5.3-codex","output":[{"type":"message","id":"msg_codex_ref","status":"completed","role":"assistant","content":[{"type":"output_text","text":%q}]}],"usage":{"input_tokens":3,"output_tokens":5,"total_tokens":8}}}`, text) + frames := codexEventFrames(text) w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") w.WriteHeader(http.StatusOK) - _, _ = io.WriteString(w, "event: response.created\n") - _, _ = io.WriteString(w, "data: "+created+"\n\n") - _, _ = io.WriteString(w, "event: response.output_text.delta\n") - _, _ = io.WriteString(w, "data: "+delta+"\n\n") - _, _ = io.WriteString(w, "event: response.completed\n") - _, _ = io.WriteString(w, "data: "+completed+"\n\n") + for _, raw := range frames { + _, _ = io.WriteString(w, "event: "+sseEventNameFromFrame(raw)+"\n") + _, _ = io.WriteString(w, "data: "+raw+"\n\n") + } _, _ = io.WriteString(w, "data: [DONE]\n\n") } + +// sseEventNameFromFrame extracts the Codex Responses event "type" from a raw +// JSON frame so the SSE event name always matches the payload instead of +// relying on a parallel hardcoded list that can drift from codexEventFrames. +func sseEventNameFromFrame(raw string) string { + var probe struct { + Type string `json:"type"` + } + if err := json.Unmarshal([]byte(raw), &probe); err != nil || probe.Type == "" { + return "response.created" + } + return probe.Type +} diff --git a/internal/refbackend/openaicodex/server_test.go b/internal/refbackend/openaicodex/server_test.go index 8e3acf23..0b6a03d0 100644 --- a/internal/refbackend/openaicodex/server_test.go +++ b/internal/refbackend/openaicodex/server_test.go @@ -2,16 +2,20 @@ package openaicodex_test import ( "bufio" + "encoding/json" "io" "net/http" "net/http/httptest" + "slices" "strings" "testing" + "time" + gorillawebsocket "github.com/gorilla/websocket" refbackend "github.com/matdev83/go-llm-interactive-proxy/internal/refbackend/openaicodex" ) -const streamBody = `{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"message","role":"user","content":"hi"}]}` +const streamBody = `{"model":"gpt-5.3-codex-spark","stream":true,"input":[{"type":"message","role":"user","content":"hi"}]}` func setCodexHeaders(req *http.Request, token string) { req.Header.Set("Authorization", "Bearer "+token) @@ -89,7 +93,7 @@ func TestServer_happyPath_streamsSSEAndCapturesRequest(t *testing.T) { t.Fatalf("chatgpt-account-id: %q", got.ChatGPTAccountID) } model, ok := got.Body["model"].(string) - if !ok || model != "gpt-5.3-codex" { + if !ok || model != "gpt-5.3-codex-spark" { t.Fatalf("body model: %#v", got.Body["model"]) } } @@ -258,3 +262,92 @@ func containsSubstring(events []string, sub string) bool { } return false } + +func TestServer_webSocketUpgrade_streamsEventFramesAndCaptures(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", OutputText: "ws-ok"}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + dialer := gorillawebsocket.Dialer{HandshakeTimeout: 5 * time.Second} + hdr := http.Header{} + hdr.Set("Authorization", "Bearer sk-codex") + hdr.Set("OpenAI-Beta", "responses=experimental") + hdr.Set("originator", "lip-test") + hdr.Set("Codex-Task-Type", "code") + hdr.Set("conversation_id", "conv-ws") + hdr.Set("session_id", "sess-ws") + + wsURL := strings.Replace(ts.URL, "http://", "ws://", 1) + "/backend-api/codex/responses" + conn, resp, err := dialer.Dial(wsURL, hdr) + if err != nil { + t.Fatalf("dial: %v (resp=%v)", err, resp) + } + defer func() { _ = conn.Close() }() + + frame := `{"type":"response.create","model":"gpt-5.3-codex-spark","store":false,"input":[{"type":"message","role":"user","content":"hi"}]}` + if err := conn.WriteMessage(gorillawebsocket.TextMessage, []byte(frame)); err != nil { + t.Fatalf("write: %v", err) + } + + var types []string + for i := range 3 { + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + _, data, rerr := conn.ReadMessage() + if rerr != nil { + t.Fatalf("read[%d]: %v", i, rerr) + } + var ev struct { + Type string `json:"type"` + } + if err := json.Unmarshal(data, &ev); err != nil { + t.Fatalf("decode[%d]: %v: %s", i, err, data) + } + types = append(types, ev.Type) + } + if !slices.Contains(types, "response.created") || !slices.Contains(types, "response.completed") { + t.Fatalf("event types: %v", types) + } + + got := srv.LatestRequest() + if got.Transport != "websocket" { + t.Fatalf("captured transport = %q, want websocket", got.Transport) + } + if got.ConversationID != "conv-ws" { + t.Fatalf("conversation id: %q", got.ConversationID) + } + if m, _ := got.Body["model"].(string); m != "gpt-5.3-codex-spark" { + t.Fatalf("frame model: %#v", got.Body["model"]) + } + if _, hasStream := got.Body["stream"]; hasStream { + t.Fatalf("WS frame must not carry stream field: %#v", got.Body) + } +} + +func TestServer_webSocketForcedFail_closesBeforeEvents(t *testing.T) { + t.Parallel() + srv := refbackend.New(refbackend.Config{Token: "sk-codex", ForcedWSFailure: refbackend.WSFailurePolicyCloseBeforeEvent}) + ts := httptest.NewServer(srv.Handler()) + t.Cleanup(ts.Close) + + dialer := gorillawebsocket.Dialer{HandshakeTimeout: 5 * time.Second} + hdr := http.Header{} + hdr.Set("Authorization", "Bearer sk-codex") + hdr.Set("OpenAI-Beta", "responses=experimental") + hdr.Set("originator", "lip-test") + hdr.Set("Codex-Task-Type", "code") + hdr.Set("conversation_id", "conv-ws") + hdr.Set("session_id", "sess-ws") + + wsURL := strings.Replace(ts.URL, "http://", "ws://", 1) + "/backend-api/codex/responses" + conn, _, err := dialer.Dial(wsURL, hdr) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer func() { _ = conn.Close() }() + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + _ = conn.WriteMessage(gorillawebsocket.TextMessage, []byte(`{"type":"response.create"}`)) + if _, _, rerr := conn.ReadMessage(); rerr == nil { + t.Fatal("expected read failure before first event") + } +} diff --git a/internal/stdhttp/default_route_frontends_test.go b/internal/stdhttp/default_route_frontends_test.go index 4b1ac757..ce3b08d5 100644 --- a/internal/stdhttp/default_route_frontends_test.go +++ b/internal/stdhttp/default_route_frontends_test.go @@ -73,6 +73,41 @@ func TestOmittedRoute_openaiResponses_usesEffectiveDefaultRoute(t *testing.T) { } } +func TestBodyRoute_openaiResponses_usesMountedRoutePrefixes(t *testing.T) { + t.Parallel() + reg := testRegistryWithStdBundle(t) + var cap sync.Map + ex := testkit.NewStubExecutor(t, lipapi.NewBackendCaps(lipapi.CapabilityStreaming), "ok", &cap) + mux := http.NewServeMux() + if err := MountBundledFrontends(MountBundledFrontendsInput{ + Mux: mux, + Exec: ex, + DefaultRouteSelector: unifiedPolicyRoute, + Plugins: []config.PluginConfig{{ID: "openai-responses", Enabled: true}}, + RoutePrefixes: []string{"stub"}, + MaxRequestBodyBytes: 0, + Reg: reg, + }); err != nil { + t.Fatal(err) + } + body := []byte(`{"model":"stub:gpt-5.5?reasoning_effort=low","stream":false,"input":[{"role":"user","content":"ping"}]}`) + req := httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rr := httptest.NewRecorder() + mux.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("status %d: %s", rr.Code, rr.Body.String()) + } + v, ok := cap.Load("last") + if !ok { + t.Fatal("expected captured call") + } + call := testkit.MustLIPCall(t, v) + if got := call.Route.Selector; got != "stub:gpt-5.5?reasoning_effort=low" { + t.Fatalf("route selector %q", got) + } +} + func TestOmittedRoute_openaiLegacy_usesEffectiveDefaultRoute(t *testing.T) { t.Parallel() reg := testRegistryWithStdBundle(t) diff --git a/internal/stdhttp/mount.go b/internal/stdhttp/mount.go index 2fc44ea9..bb340763 100644 --- a/internal/stdhttp/mount.go +++ b/internal/stdhttp/mount.go @@ -23,6 +23,7 @@ type MountBundledFrontendsInput struct { Mux *http.ServeMux Exec *runtime.Executor DefaultRouteSelector string + RoutePrefixes []string Plugins []config.PluginConfig MaxRequestBodyBytes int64 PreRequestKeepalive lipsdk.FrontendKeepaliveConfig @@ -67,6 +68,7 @@ func MountBundledFrontends(in MountBundledFrontendsInput) error { PluginCfg: p.Config, Exec: in.Exec, DefaultRoute: in.DefaultRouteSelector, + RoutePrefixes: in.RoutePrefixes, MaxRequestBodyBytes: in.MaxRequestBodyBytes, TrafficPorts: in.TrafficPorts, PreRequestKeepalive: in.PreRequestKeepalive, diff --git a/internal/stdhttp/server.go b/internal/stdhttp/server.go index 39fcb566..41595016 100644 --- a/internal/stdhttp/server.go +++ b/internal/stdhttp/server.go @@ -239,6 +239,7 @@ func prepareStandardHandler( Mux: mux, Exec: exec, DefaultRouteSelector: route, + RoutePrefixes: built.RoutePrefixes, Plugins: cfg.Plugins.Frontends, MaxRequestBodyBytes: maxBody, Reg: reg, diff --git a/internal/stdhttp/standard_wiring_roundtrip_test.go b/internal/stdhttp/standard_wiring_roundtrip_test.go index e6ccdd9b..f79998be 100644 --- a/internal/stdhttp/standard_wiring_roundtrip_test.go +++ b/internal/stdhttp/standard_wiring_roundtrip_test.go @@ -85,6 +85,7 @@ models: Exec: built.Executor, DefaultRouteSelector: route, Plugins: []config.PluginConfig{{ID: "openai-responses", Enabled: true}}, + RoutePrefixes: built.RoutePrefixes, MaxRequestBodyBytes: 0, Reg: reg, }); err != nil { diff --git a/pkg/lipapi/route_params.go b/pkg/lipapi/route_params.go index 44053d59..a0d9279b 100644 --- a/pkg/lipapi/route_params.go +++ b/pkg/lipapi/route_params.go @@ -8,8 +8,9 @@ import ( ) // MergeRouteQueryIntoGenerationOptions overlays URL query parameters from a route -// primary onto base options. Fields already set on base are left unchanged (explicit -// request / canonical call wins over route defaults). +// primary onto base options. URI params are explicit routing directives and OVERRIDE +// any corresponding value already set on base (the per-request body/call settings). +// Keys absent from the query leave the base value unchanged. // // Recognized keys (first value wins per key): temperature, top_p, max_output_tokens, // reasoning_effort, parallel_tool_calls (true/false/1/0). @@ -22,46 +23,36 @@ func MergeRouteQueryIntoGenerationOptions(base GenerationOptions, q url.Values) return out, nil } - if out.Temperature == nil { - if s := firstQuery(q, "temperature"); s != "" { - v, err := strconv.ParseFloat(s, 64) - if err != nil { - return GenerationOptions{}, fmt.Errorf("route param temperature: %w", err) - } - out.Temperature = &v + if s := firstQuery(q, "temperature"); s != "" { + v, err := strconv.ParseFloat(s, 64) + if err != nil { + return GenerationOptions{}, fmt.Errorf("route param temperature: %w", err) } + out.Temperature = &v } - if out.TopP == nil { - if s := firstQuery(q, "top_p", "topP"); s != "" { - v, err := strconv.ParseFloat(s, 64) - if err != nil { - return GenerationOptions{}, fmt.Errorf("route param top_p: %w", err) - } - out.TopP = &v + if s := firstQuery(q, "top_p", "topP"); s != "" { + v, err := strconv.ParseFloat(s, 64) + if err != nil { + return GenerationOptions{}, fmt.Errorf("route param top_p: %w", err) } + out.TopP = &v } - if out.MaxOutputTokens == nil { - if s := firstQuery(q, "max_output_tokens", "max_tokens"); s != "" { - v, err := strconv.Atoi(s) - if err != nil { - return GenerationOptions{}, fmt.Errorf("route param max_output_tokens: %w", err) - } - out.MaxOutputTokens = &v + if s := firstQuery(q, "max_output_tokens", "max_tokens"); s != "" { + v, err := strconv.Atoi(s) + if err != nil { + return GenerationOptions{}, fmt.Errorf("route param max_output_tokens: %w", err) } + out.MaxOutputTokens = &v } - if out.ReasoningEffort == "" { - if s := firstQuery(q, "reasoning_effort"); s != "" { - out.ReasoningEffort = s - } + if s := firstQuery(q, "reasoning_effort"); s != "" { + out.ReasoningEffort = s } - if out.ParallelToolCalls == nil { - if s := firstQuery(q, "parallel_tool_calls"); s != "" { - b, err := parseBoolParam(s) - if err != nil { - return GenerationOptions{}, fmt.Errorf("route param parallel_tool_calls: %w", err) - } - out.ParallelToolCalls = &b + if s := firstQuery(q, "parallel_tool_calls"); s != "" { + b, err := parseBoolParam(s) + if err != nil { + return GenerationOptions{}, fmt.Errorf("route param parallel_tool_calls: %w", err) } + out.ParallelToolCalls = &b } if err := out.validate(); err != nil { diff --git a/pkg/lipapi/route_params_test.go b/pkg/lipapi/route_params_test.go index daa77c19..57dda247 100644 --- a/pkg/lipapi/route_params_test.go +++ b/pkg/lipapi/route_params_test.go @@ -50,19 +50,60 @@ func TestMergeRouteQueryIntoGenerationOptions_fillsFromRoute(t *testing.T) { } } -func TestMergeRouteQueryIntoGenerationOptions_callWinsOverRoute(t *testing.T) { +func TestMergeRouteQueryIntoGenerationOptions_routeOverridesCall(t *testing.T) { t.Parallel() temp := 0.1 base := lipapi.GenerationOptions{Temperature: &temp} q := url.Values{} q.Set("temperature", "0.99") + got, err := lipapi.MergeRouteQueryIntoGenerationOptions(base, q) + if err != nil { + t.Fatal(err) + } + if got.Temperature == nil || *got.Temperature != 0.99 { + t.Fatalf("route temperature should override call, got %#v", got.Temperature) + } +} + +func TestMergeRouteQueryIntoGenerationOptions_routeOverridesReasoningEffort(t *testing.T) { + t.Parallel() + // Route selectors are the user's explicit routing contract. The OpenCode/Codex + // latency bugs investigated around this connector were local protocol-shaping + // problems, not a reason to weaken URI reasoning_effort overrides. + base := lipapi.GenerationOptions{ReasoningEffort: "medium"} + q := url.Values{} + q.Set("reasoning_effort", "xhigh") + + got, err := lipapi.MergeRouteQueryIntoGenerationOptions(base, q) + if err != nil { + t.Fatal(err) + } + if got.ReasoningEffort != "xhigh" { + t.Fatalf("route reasoning_effort should override call, got %q", got.ReasoningEffort) + } +} + +func TestMergeRouteQueryIntoGenerationOptions_routeAbsentLeavesCall(t *testing.T) { + t.Parallel() + temp := 0.1 + base := lipapi.GenerationOptions{Temperature: &temp, ReasoningEffort: "medium"} + // Query sets nothing for temperature/reasoning_effort. + q := url.Values{} + q.Set("top_p", "0.5") + got, err := lipapi.MergeRouteQueryIntoGenerationOptions(base, q) if err != nil { t.Fatal(err) } if got.Temperature == nil || *got.Temperature != 0.1 { - t.Fatalf("call temperature should win, got %#v", got.Temperature) + t.Fatalf("absent route key should leave call value, got %#v", got.Temperature) + } + if got.ReasoningEffort != "medium" { + t.Fatalf("absent route key should leave call value, got %q", got.ReasoningEffort) + } + if got.TopP == nil || *got.TopP != 0.5 { + t.Fatalf("top_p: %#v", got.TopP) } } diff --git a/pkg/lipsdk/factory.go b/pkg/lipsdk/factory.go index 4ef16c34..569e1adb 100644 --- a/pkg/lipsdk/factory.go +++ b/pkg/lipsdk/factory.go @@ -28,6 +28,8 @@ type FrontendMountOptions struct { Exec ExecutorView // DefaultRoute is the selector used when the frontend protocol omits a route/header override. DefaultRoute string + // RoutePrefixes are backend route-selector prefixes accepted from protocol model fields. + RoutePrefixes []string // MaxRequestBodyBytes caps inbound HTTP request size. Zero means the frontend should use its // own default limit. MaxRequestBodyBytes int64