From a6c7b1b35d3ac803c5f7871fdac9212784637f00 Mon Sep 17 00:00:00 2001 From: Mateusz Date: Wed, 1 Jul 2026 15:11:58 +0200 Subject: [PATCH 1/2] split codex ws and compat cleanup --- .../plugins/backends/openaicodex/stream.go | 3 + internal/plugins/backends/openaicodex/ws.go | 425 ++---------------- .../backends/openaicodex/ws_session.go | 208 +++++++++ .../plugins/backends/openaicodex/ws_stream.go | 160 +++++++ .../codexclientcompat/bridge_droid.go | 89 ++++ .../codexclientcompat/bridge_hermes.go | 61 +++ .../codexclientcompat/bridge_opencode.go | 204 +++++++++ .../features/codexclientcompat/bridge_pi.go | 64 +++ .../features/codexclientcompat/compat.go | 414 ----------------- .../plugins/frontends/anthropic/handler.go | 5 +- internal/plugins/frontends/gemini/handler.go | 5 +- .../plugins/frontends/openailegacy/handler.go | 5 +- .../frontends/openairesponses/handler.go | 5 +- 13 files changed, 838 insertions(+), 810 deletions(-) create mode 100644 internal/plugins/backends/openaicodex/ws_session.go create mode 100644 internal/plugins/backends/openaicodex/ws_stream.go create mode 100644 internal/plugins/features/codexclientcompat/bridge_droid.go create mode 100644 internal/plugins/features/codexclientcompat/bridge_hermes.go create mode 100644 internal/plugins/features/codexclientcompat/bridge_opencode.go create mode 100644 internal/plugins/features/codexclientcompat/bridge_pi.go diff --git a/internal/plugins/backends/openaicodex/stream.go b/internal/plugins/backends/openaicodex/stream.go index 8088e485..89d0d509 100644 --- a/internal/plugins/backends/openaicodex/stream.go +++ b/internal/plugins/backends/openaicodex/stream.go @@ -86,6 +86,9 @@ func looksLikeToolProtocolText(delta string) bool { if text == "" { return false } + // Treat suspected textual tool-call protocol as a stream error instead of + // dropping it silently: leaking tool syntax to the client is more damaging + // than the small false-positive risk for ordinary assistant prose. if strings.Contains(text, "to=functions.") || strings.Contains(text, "to=functions_") { return true } diff --git a/internal/plugins/backends/openaicodex/ws.go b/internal/plugins/backends/openaicodex/ws.go index 3c54ad17..8fd5c93d 100644 --- a/internal/plugins/backends/openaicodex/ws.go +++ b/internal/plugins/backends/openaicodex/ws.go @@ -6,7 +6,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "strings" "sync" @@ -20,11 +19,6 @@ import ( const wsHandshakeTimeout = 30 * time.Second -const ( - wsSessionIdleTTL = 2 * time.Minute - wsSessionMaxEntries = 256 -) - // wsFirstEventTimeout bounds the wait for the first canonical event after the // WebSocket handshake. Without it, a server that upgrades but never sends would // leave openWS blocked forever on conn.ReadMessage (which ignores ctx). It is a @@ -34,198 +28,6 @@ var wsFirstEventTimeout = 30 * time.Second var errWSPreviousResponseNotFound = errors.New("websocket previous response not found") -type wsSessionKey struct { - baseURL string - accountID string - accessToken string - conversation string -} - -type wsSessionStore struct { - mu sync.Mutex - sessions map[wsSessionKey]*wsSessionConn - idleTTL time.Duration - maxEntries int - now func() time.Time -} - -type wsSessionConn struct { - key wsSessionKey - store *wsSessionStore - sem chan struct{} - conn *websocket.Conn - lastUsed time.Time - idleTimer *time.Timer -} - -func newWSSessionStore() *wsSessionStore { - return &wsSessionStore{ - sessions: make(map[wsSessionKey]*wsSessionConn), - idleTTL: wsSessionIdleTTL, - maxEntries: wsSessionMaxEntries, - now: time.Now, - } -} - -func (s *wsSessionStore) acquire(ctx context.Context, client *http.Client, url string, cfg *Config, convID string) (*wsSessionConn, *http.Response, bool, error) { - key := wsSessionKey{ - baseURL: strings.TrimSpace(url), - accountID: strings.TrimSpace(cfg.AccountID), - accessToken: strings.TrimSpace(cfg.AccessToken), - conversation: strings.TrimSpace(convID), - } - s.mu.Lock() - session := s.sessions[key] - if session == nil { - session = &wsSessionConn{ - key: key, - store: s, - sem: make(chan struct{}, 1), - lastUsed: s.now(), - } - s.sessions[key] = session - s.pruneToCapLocked(session) - } - session.stopIdleTimerLocked() - s.mu.Unlock() - - if err := session.acquire(ctx); err != nil { - return nil, nil, false, err - } - if session.conn != nil { - return session, nil, true, nil - } - conn, resp, err := dialCodexWebSocket(ctx, client, url, cfg, convID) - if err != nil { - session.release(true) - return nil, resp, false, err - } - session.conn = conn - return session, resp, false, nil -} - -func (s *wsSessionStore) forgetLocked(key wsSessionKey, session *wsSessionConn) { - if s.sessions[key] == session { - delete(s.sessions, key) - } -} - -func (s *wsSessionStore) pruneToCapLocked(protected *wsSessionConn) { - for len(s.sessions) > s.maxEntries { - var oldestKey wsSessionKey - var oldest *wsSessionConn - for key, session := range s.sessions { - if session == protected { - continue - } - if !session.tryAcquire() { - continue - } - if oldest == nil || session.lastUsed.Before(oldest.lastUsed) { - if oldest != nil { - oldest.unlock() - } - oldestKey = key - oldest = session - continue - } - session.unlock() - } - if oldest == nil { - return - } - oldest.closeConnLocked() - s.forgetLocked(oldestKey, oldest) - oldest.unlock() - } -} - -func (s *wsSessionStore) closeIdle(key wsSessionKey, session *wsSessionConn) { - if !session.tryAcquire() { - return - } - defer session.unlock() - s.mu.Lock() - defer s.mu.Unlock() - session.closeConnLocked() - session.stopIdleTimerLocked() - s.forgetLocked(key, session) -} - -func (s *wsSessionConn) acquire(ctx context.Context) error { - if ctx == nil { - ctx = context.Background() - } - select { - case s.sem <- struct{}{}: - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -func (s *wsSessionConn) tryAcquire() bool { - select { - case s.sem <- struct{}{}: - return true - default: - return false - } -} - -func (s *wsSessionConn) unlock() { - select { - case <-s.sem: - default: - } -} - -func (s *wsSessionConn) release(closeConn bool) { - if s.store == nil { - s.unlock() - return - } - s.store.mu.Lock() - if closeConn { - s.closeConnLocked() - s.stopIdleTimerLocked() - s.store.forgetLocked(s.key, s) - } else { - s.lastUsed = s.store.now() - s.scheduleIdleTimerLocked() - } - s.store.mu.Unlock() - s.unlock() -} - -func (s *wsSessionConn) closeConnLocked() { - if s.conn == nil { - return - } - _ = s.conn.Close() - s.conn = nil -} - -func (s *wsSessionConn) stopIdleTimerLocked() { - if s.idleTimer == nil { - return - } - s.idleTimer.Stop() - s.idleTimer = nil -} - -func (s *wsSessionConn) scheduleIdleTimerLocked() { - s.stopIdleTimerLocked() - if s.store == nil || s.store.idleTTL <= 0 { - return - } - key := s.key - store := s.store - s.idleTimer = time.AfterFunc(s.store.idleTTL, func() { - store.closeIdle(key, s) - }) -} - // wsEndpoint converts an HTTPS Codex base URL into the WebSocket scheme used by // the Codex Responses WebSocket transport. Path handling mirrors // responsesEndpoint so the same base_url value configures both transports. @@ -299,6 +101,38 @@ type wsOpenAttemptState struct { allowStaleRetry bool } +type wsOpenAttemptPayload struct { + env *codexOpenEnv + cfg *Config + call lipapi.Call + continuation *wsContinuationStore + fullPayload Payload + fullInputFP []string + continuationApplied bool +} + +func newWSOpenAttemptPayload(ctx context.Context, env *codexOpenEnv, cfg *Config, model string, call lipapi.Call, continuation *wsContinuationStore, state wsOpenAttemptState) wsOpenAttemptPayload { + env.payload.Model = model + fullPayload := env.payload + fullInputFP := append([]string(nil), env.inputFingerprints...) + return wsOpenAttemptPayload{ + env: env, + cfg: cfg, + call: call, + continuation: continuation, + fullPayload: fullPayload, + fullInputFP: fullInputFP, + continuationApplied: state.allowContinuation && continuation.prepareWithFingerprints(ctx, cfg, call, &env.payload, fullInputFP), + } +} + +func (p wsOpenAttemptPayload) rollback() { + if p.continuationApplied { + p.continuation.invalidateWithFingerprints(p.cfg, p.call, &p.fullPayload, p.fullInputFP) + } + p.env.payload = p.fullPayload +} + func openWSPreparedAttempt(ctx context.Context, env *codexOpenEnv, cfg *Config, model string, call lipapi.Call, usageEst *usageEstimator, sessions *wsSessionStore, continuation *wsContinuationStore) (lipapi.ManagedEventStream, *http.Response, []byte, error) { state := wsOpenAttemptState{ allowContinuation: true, @@ -325,33 +159,20 @@ func openWSPreparedAttemptOnce(ctx context.Context, env *codexOpenEnv, cfg *Conf if continuation == nil { continuation = newWSContinuationStore(codexContinuationTTL, codexContinuationMaxEntries) } - env.payload.Model = model - fullPayload := env.payload - fullInputFingerprints := append([]string(nil), env.inputFingerprints...) - continuationApplied := state.allowContinuation && continuation.prepareWithFingerprints(ctx, cfg, call, &env.payload, fullInputFingerprints) - clearPreparedContinuation := func() { - if continuationApplied { - continuation.invalidateWithFingerprints(cfg, call, &fullPayload, fullInputFingerprints) - } - } - restoreFullPayload := func() { - env.payload = fullPayload - } + attemptPayload := newWSOpenAttemptPayload(ctx, env, cfg, model, call, continuation, state) frame, err := payloadToWSResponseCreate(env.payload) if err != nil { - clearPreparedContinuation() - restoreFullPayload() + attemptPayload.rollback() return nil, nil, nil, wsOpenNoRetry, err } session, resp, reusedSession, err := sessions.acquire(ctx, env.client, wsEndpoint(cfg.BaseURL), cfg, env.convID) if err != nil { - clearPreparedContinuation() // Restore the full payload snapshot before returning so a rotation retry on // another account does not inherit this attempt's continuation-trimmed Input // and PreviousResponseID. The other retry paths restore below for the same // reason; the handshake-error path must too because it hands resp back to the // managed loop, which rotates accounts on 401/403/429 reusing this env. - restoreFullPayload() + attemptPayload.rollback() // Return the (body-closed) handshake response so the managed WS path can // classify 401/403/429 handshakes and rotate to the next account. return nil, resp, nil, wsOpenNoRetry, err @@ -360,12 +181,10 @@ func openWSPreparedAttemptOnce(ctx context.Context, env *codexOpenEnv, cfg *Conf if err := writeWSResponseCreate(ctx, conn, frame); err != nil { session.release(true) if reusedSession && state.allowStaleRetry { - clearPreparedContinuation() - restoreFullPayload() + attemptPayload.rollback() return nil, nil, nil, wsOpenRetryFreshSession, err } - clearPreparedContinuation() - restoreFullPayload() + attemptPayload.rollback() return nil, nil, nil, wsOpenNoRetry, err } effectiveModel := strings.TrimSpace(env.payload.Model) @@ -380,25 +199,21 @@ func openWSPreparedAttemptOnce(ctx context.Context, env *codexOpenEnv, cfg *Conf if rerr != nil { session.release(true) if reusedSession && state.allowStaleRetry && isWSFallbackError(ctx, rerr) { - clearPreparedContinuation() - restoreFullPayload() + attemptPayload.rollback() return nil, nil, nil, wsOpenRetryFreshSession, rerr } - clearPreparedContinuation() - restoreFullPayload() + attemptPayload.rollback() return nil, nil, nil, wsOpenNoRetry, rerr } if isWSFreePlanRejection(rawFirst, env.downgrade, env.originalModel) { session.release(true) - clearPreparedContinuation() - restoreFullPayload() + attemptPayload.rollback() return nil, resp, rawFirst, wsOpenNoRetry, fmt.Errorf("%s: websocket model rejected before first event", ID) } mapper := newCodexEventMapper(call.MaxPendingWireEvents) if err := mapper.handleData(string(rawFirst)); err != nil { session.release(true) - clearPreparedContinuation() - restoreFullPayload() + attemptPayload.rollback() return nil, nil, rawFirst, wsOpenNoRetry, err } wsStream := newWSStreamWithMapper(conn, mapper) @@ -415,17 +230,15 @@ func openWSPreparedAttemptOnce(ctx context.Context, env *codexOpenEnv, cfg *Conf managed, rerr = openManagedUntilCommitted(ctx, wsStream, usageEst, call, effectiveModel, wsFirstEventTimeout) } if rerr != nil { - if continuationApplied && errors.Is(rerr, errWSPreviousResponseNotFound) { - continuation.invalidateWithFingerprints(cfg, call, &fullPayload, fullInputFingerprints) - restoreFullPayload() + if attemptPayload.continuationApplied && errors.Is(rerr, errWSPreviousResponseNotFound) { + attemptPayload.rollback() wsStream.releaseOnce(true) return nil, nil, rawFirst, wsOpenRetryWithoutContinuation, rerr } - clearPreparedContinuation() - restoreFullPayload() + attemptPayload.rollback() return nil, nil, rawFirst, wsOpenNoRetry, wsPreFirstEventFailure(rerr) } - managed = newCodexContinuationRecordingStream(managed, cfg, call, fullPayload, fullInputFingerprints, mapper, continuation) + managed = newCodexContinuationRecordingStream(managed, cfg, call, attemptPayload.fullPayload, attemptPayload.fullInputFP, mapper, continuation) // The opening boundary has been reached: strict websocket mode returns after // the first canonical event, while auto mode waits until output is committed // or terminal. Clear the deadline so subsequent streaming reads are governed @@ -658,151 +471,3 @@ func payloadToWSResponseCreate(p Payload) (json.RawMessage, error) { } return out, nil } - -var _ lipapi.ManagedEventStream = (*wsStream)(nil) - -type wsStream struct { - mapper *codexEventMapper - mu sync.Mutex - conn *websocket.Conn - closed bool - release func(closeConn bool) - releaseOnceF sync.Once -} - -func newWSStream(conn *websocket.Conn, maxPending int) *wsStream { - return newWSStreamWithMapper(conn, newCodexEventMapper(maxPending)) -} - -// newWSStreamWithMapper builds a wsStream over a pre-existing event mapper. The caller -// may have already populated the mapper's pending queue (e.g. from a pre-read first -// frame); the stream's Recv drains pending before reading the next wire frame. -func newWSStreamWithMapper(conn *websocket.Conn, mapper *codexEventMapper) *wsStream { - return &wsStream{ - mapper: mapper, - conn: conn, - } -} - -func (s *wsStream) Recv(ctx context.Context) (lipapi.Event, error) { - if ctx == nil { - return lipapi.Event{}, lipapi.ErrNilContext - } - if err := ctx.Err(); err != nil { - return lipapi.Event{}, err - } - for { - s.mu.Lock() - if s.closed { - s.mu.Unlock() - return lipapi.Event{}, io.EOF - } - if ev, ok := s.mapper.pending.PopFront(); ok { - s.mu.Unlock() - return ev, nil - } - if s.mapper.terminal { - s.mu.Unlock() - s.releaseOnce(false) - return lipapi.Event{}, io.EOF - } - s.mu.Unlock() - - text, ok, err := s.readMessage(ctx) - if err != nil { - s.mu.Lock() - closed := s.closed - s.mu.Unlock() - if closed { - return lipapi.Event{}, io.EOF - } - s.releaseOnce(true) - return lipapi.Event{}, err - } - if !ok { - s.mu.Lock() - closed := s.closed - s.mu.Unlock() - if closed { - return lipapi.Event{}, io.EOF - } - s.releaseOnce(false) - return lipapi.Event{}, io.EOF - } - if text == "" { - continue - } - - s.mu.Lock() - if s.closed { - s.mu.Unlock() - continue - } - if err := s.mapper.handleData(text); err != nil { - s.mu.Unlock() - return lipapi.Event{}, err - } - s.mu.Unlock() - } -} - -func (s *wsStream) readMessage(ctx context.Context) (string, bool, error) { - stopCancel := context.AfterFunc(ctx, func() { - _ = s.conn.SetReadDeadline(time.Now()) - }) - defer stopCancel() - _, data, err := s.conn.ReadMessage() - if err != nil { - if ctxErr := ctx.Err(); ctxErr != nil { - _ = s.conn.SetReadDeadline(time.Time{}) - return "", false, ctxErr - } - if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - return "", false, io.EOF - } - return "", false, newWSStreamReadError(err) - } - text := strings.TrimSpace(string(data)) - if text == "" { - return "", true, nil - } - return text, true, nil -} - -func (s *wsStream) Close() error { - closeConn := true - s.mu.Lock() - if s.mapper.terminal { - closeConn = false - } - s.mu.Unlock() - if closeConn && s.conn != nil { - // Close first, without taking s.mu: Recv holds that lock while blocked in - // ReadMessage, so taking it before closing would deadlock cancellation. - _ = s.conn.Close() - } - s.mu.Lock() - defer s.mu.Unlock() - if s.closed { - return nil - } - s.closed = true - s.releaseOnce(closeConn) - return nil -} - -func (s *wsStream) Cancel(context.Context, lipapi.CancelCause) lipapi.CancelResult { - // Codex WebSocket does not have a request-cancel frame in this adapter. Close - // the socket instead of pretending cancellation is protocol-level; this also - // prevents reuse of an in-flight session whose upstream generation may still be - // producing frames. - return lipapi.CancelResult{Mode: lipapi.CancelModeCloseOnly, Err: s.Close()} -} - -func (s *wsStream) releaseOnce(closeConn bool) { - s.releaseOnceF.Do(func() { - if s.release != nil { - s.release(closeConn) - } - }) -} diff --git a/internal/plugins/backends/openaicodex/ws_session.go b/internal/plugins/backends/openaicodex/ws_session.go new file mode 100644 index 00000000..99a626a1 --- /dev/null +++ b/internal/plugins/backends/openaicodex/ws_session.go @@ -0,0 +1,208 @@ +package openaicodex + +import ( + "context" + "net/http" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +const ( + wsSessionIdleTTL = 2 * time.Minute + wsSessionMaxEntries = 256 +) + +type wsSessionKey struct { + baseURL string + accountID string + accessToken string + conversation string +} + +type wsSessionStore struct { + mu sync.Mutex + sessions map[wsSessionKey]*wsSessionConn + idleTTL time.Duration + maxEntries int + now func() time.Time +} + +type wsSessionConn struct { + key wsSessionKey + store *wsSessionStore + sem chan struct{} + conn *websocket.Conn + lastUsed time.Time + idleTimer *time.Timer +} + +func newWSSessionStore() *wsSessionStore { + return &wsSessionStore{ + sessions: make(map[wsSessionKey]*wsSessionConn), + idleTTL: wsSessionIdleTTL, + maxEntries: wsSessionMaxEntries, + now: time.Now, + } +} + +func (s *wsSessionStore) acquire(ctx context.Context, client *http.Client, url string, cfg *Config, convID string) (*wsSessionConn, *http.Response, bool, error) { + key := wsSessionKey{ + baseURL: strings.TrimSpace(url), + accountID: strings.TrimSpace(cfg.AccountID), + accessToken: strings.TrimSpace(cfg.AccessToken), + conversation: strings.TrimSpace(convID), + } + s.mu.Lock() + session := s.sessions[key] + if session == nil { + session = &wsSessionConn{ + key: key, + store: s, + sem: make(chan struct{}, 1), + lastUsed: s.now(), + } + s.sessions[key] = session + s.pruneToCapLocked(session) + } + session.stopIdleTimerLocked() + s.mu.Unlock() + + if err := session.acquire(ctx); err != nil { + return nil, nil, false, err + } + if session.conn != nil { + return session, nil, true, nil + } + conn, resp, err := dialCodexWebSocket(ctx, client, url, cfg, convID) + if err != nil { + session.release(true) + return nil, resp, false, err + } + session.conn = conn + return session, resp, false, nil +} + +func (s *wsSessionStore) forgetLocked(key wsSessionKey, session *wsSessionConn) { + if s.sessions[key] == session { + delete(s.sessions, key) + } +} + +func (s *wsSessionStore) pruneToCapLocked(protected *wsSessionConn) { + for len(s.sessions) > s.maxEntries { + var oldestKey wsSessionKey + var oldest *wsSessionConn + for key, session := range s.sessions { + if session == protected { + continue + } + if !session.tryAcquire() { + continue + } + if oldest == nil || session.lastUsed.Before(oldest.lastUsed) { + if oldest != nil { + oldest.unlock() + } + oldestKey = key + oldest = session + continue + } + session.unlock() + } + if oldest == nil { + return + } + oldest.closeConnLocked() + s.forgetLocked(oldestKey, oldest) + oldest.unlock() + } +} + +func (s *wsSessionStore) closeIdle(key wsSessionKey, session *wsSessionConn) { + if !session.tryAcquire() { + return + } + defer session.unlock() + s.mu.Lock() + defer s.mu.Unlock() + session.closeConnLocked() + session.stopIdleTimerLocked() + s.forgetLocked(key, session) +} + +func (s *wsSessionConn) acquire(ctx context.Context) error { + if ctx == nil { + ctx = context.Background() + } + select { + case s.sem <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (s *wsSessionConn) tryAcquire() bool { + select { + case s.sem <- struct{}{}: + return true + default: + return false + } +} + +func (s *wsSessionConn) unlock() { + select { + case <-s.sem: + default: + } +} + +func (s *wsSessionConn) release(closeConn bool) { + if s.store == nil { + s.unlock() + return + } + s.store.mu.Lock() + if closeConn { + s.closeConnLocked() + s.stopIdleTimerLocked() + s.store.forgetLocked(s.key, s) + } else { + s.lastUsed = s.store.now() + s.scheduleIdleTimerLocked() + } + s.store.mu.Unlock() + s.unlock() +} + +func (s *wsSessionConn) closeConnLocked() { + if s.conn == nil { + return + } + _ = s.conn.Close() + s.conn = nil +} + +func (s *wsSessionConn) stopIdleTimerLocked() { + if s.idleTimer == nil { + return + } + s.idleTimer.Stop() + s.idleTimer = nil +} + +func (s *wsSessionConn) scheduleIdleTimerLocked() { + s.stopIdleTimerLocked() + if s.store == nil || s.store.idleTTL <= 0 { + return + } + key := s.key + store := s.store + s.idleTimer = time.AfterFunc(s.store.idleTTL, func() { + store.closeIdle(key, s) + }) +} diff --git a/internal/plugins/backends/openaicodex/ws_stream.go b/internal/plugins/backends/openaicodex/ws_stream.go new file mode 100644 index 00000000..f34006b4 --- /dev/null +++ b/internal/plugins/backends/openaicodex/ws_stream.go @@ -0,0 +1,160 @@ +package openaicodex + +import ( + "context" + "io" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +var _ lipapi.ManagedEventStream = (*wsStream)(nil) + +type wsStream struct { + mapper *codexEventMapper + mu sync.Mutex + conn *websocket.Conn + closed bool + release func(closeConn bool) + releaseOnceF sync.Once +} + +func newWSStream(conn *websocket.Conn, maxPending int) *wsStream { + return newWSStreamWithMapper(conn, newCodexEventMapper(maxPending)) +} + +// newWSStreamWithMapper builds a wsStream over a pre-existing event mapper. The caller +// may have already populated the mapper's pending queue (e.g. from a pre-read first +// frame); the stream's Recv drains pending before reading the next wire frame. +func newWSStreamWithMapper(conn *websocket.Conn, mapper *codexEventMapper) *wsStream { + return &wsStream{ + mapper: mapper, + conn: conn, + } +} + +func (s *wsStream) Recv(ctx context.Context) (lipapi.Event, error) { + if ctx == nil { + return lipapi.Event{}, lipapi.ErrNilContext + } + if err := ctx.Err(); err != nil { + return lipapi.Event{}, err + } + for { + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return lipapi.Event{}, io.EOF + } + if ev, ok := s.mapper.pending.PopFront(); ok { + s.mu.Unlock() + return ev, nil + } + if s.mapper.terminal { + s.mu.Unlock() + s.releaseOnce(false) + return lipapi.Event{}, io.EOF + } + s.mu.Unlock() + + text, ok, err := s.readMessage(ctx) + if err != nil { + s.mu.Lock() + closed := s.closed + s.mu.Unlock() + if closed { + return lipapi.Event{}, io.EOF + } + s.releaseOnce(true) + return lipapi.Event{}, err + } + if !ok { + s.mu.Lock() + closed := s.closed + s.mu.Unlock() + if closed { + return lipapi.Event{}, io.EOF + } + s.releaseOnce(false) + return lipapi.Event{}, io.EOF + } + if text == "" { + continue + } + + s.mu.Lock() + if s.closed { + s.mu.Unlock() + continue + } + if err := s.mapper.handleData(text); err != nil { + s.mu.Unlock() + return lipapi.Event{}, err + } + s.mu.Unlock() + } +} + +func (s *wsStream) readMessage(ctx context.Context) (string, bool, error) { + stopCancel := context.AfterFunc(ctx, func() { + _ = s.conn.SetReadDeadline(time.Now()) + }) + defer stopCancel() + _, data, err := s.conn.ReadMessage() + if err != nil { + if ctxErr := ctx.Err(); ctxErr != nil { + _ = s.conn.SetReadDeadline(time.Time{}) + return "", false, ctxErr + } + if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + return "", false, io.EOF + } + return "", false, newWSStreamReadError(err) + } + text := strings.TrimSpace(string(data)) + if text == "" { + return "", true, nil + } + return text, true, nil +} + +func (s *wsStream) Close() error { + closeConn := true + s.mu.Lock() + if s.mapper.terminal { + closeConn = false + } + s.mu.Unlock() + if closeConn && s.conn != nil { + // Close first, without taking s.mu: Recv holds that lock while blocked in + // ReadMessage, so taking it before closing would deadlock cancellation. + _ = s.conn.Close() + } + s.mu.Lock() + defer s.mu.Unlock() + if s.closed { + return nil + } + s.closed = true + s.releaseOnce(closeConn) + return nil +} + +func (s *wsStream) Cancel(context.Context, lipapi.CancelCause) lipapi.CancelResult { + // Codex WebSocket does not have a request-cancel frame in this adapter. Close + // the socket instead of pretending cancellation is protocol-level; this also + // prevents reuse of an in-flight session whose upstream generation may still be + // producing frames. + return lipapi.CancelResult{Mode: lipapi.CancelModeCloseOnly, Err: s.Close()} +} + +func (s *wsStream) releaseOnce(closeConn bool) { + s.releaseOnceF.Do(func() { + if s.release != nil { + s.release(closeConn) + } + }) +} diff --git a/internal/plugins/features/codexclientcompat/bridge_droid.go b/internal/plugins/features/codexclientcompat/bridge_droid.go new file mode 100644 index 00000000..9b50943f --- /dev/null +++ b/internal/plugins/features/codexclientcompat/bridge_droid.go @@ -0,0 +1,89 @@ +package codexclientcompat + +import ( + "slices" + "sort" + "strings" +) + +var ( + droidNativeToolNames = map[string]struct{}{ + "Read": {}, "LS": {}, "Execute": {}, "Edit": {}, "Grep": {}, "Glob": {}, + "Create": {}, "TodoWrite": {}, "WebSearch": {}, "FetchUrl": {}, "ExitSpecMode": {}, + } + droidSystemPromptKeywords = []string{ + "you are droid", + "droid, an ai", + "factory droid", + } + droidUserAgentTokens = []string{"factory-cli", "factory_cli", "factorydroid", "droid"} +) + +func droidAgentMatch(in compatInput) bool { + return slices.ContainsFunc(in.agents, droidUserAgentMatch) +} + +func droidPromptMatch(in compatInput) bool { + lower := strings.ToLower(in.prompt) + for _, keyword := range droidSystemPromptKeywords { + if strings.Contains(lower, keyword) { + return true + } + } + return false +} + +func droidUserAgentMatch(userAgent string) bool { + lower := strings.ToLower(userAgent) + for _, pattern := range droidUserAgentTokens { + if strings.Contains(lower, pattern) { + return true + } + } + return false +} + +func isDroidHarnessText(text string) bool { + lower := strings.ToLower(text) + return strings.Contains(lower, "factory droid") && strings.Contains(lower, "execute") && strings.Contains(lower, "todowrite") +} + +func buildDroidBridge(availableTools []string) string { + native := sortedNativeDroidTools() + available := availableTools + if len(available) == 0 { + available = native + } + availableText := joinBacktickList(available) + nativeText := joinBacktickList(native) + return droidBridgeMarker + ":\n" + + "- This session is using Factory Droid tools, not Codex-native tools.\n" + + "- Use only tool names that are actually available in this session: " + availableText + ".\n" + + "- Prefer the native Factory Droid tool family when available: " + nativeText + ".\n" + + "- Use Droid argument shapes exactly for the native file/execute tools: `Read(file_path, offset?, limit?)`, `LS(directory_path?)`, `Execute(command, timeout?, cwd?)`, `Edit(file_path, old_str, new_str)`, `Grep(pattern, path?, file_pattern?, max_results?)`, `Glob(pattern, max_results?)`, `Create(file_path, content)`.\n" + + "- Do not emit Codex-native tool names such as `read`, `read_file`, `bash`, `shell`, `apply_patch`, `grep_files`, or `list_dir`.\n" + + "- Use `TodoWrite` instead of Codex task-planner tools, `WebSearch` for web search, and `FetchUrl` for direct URL fetches when those tools are available.\n" + + "- Keep tool arguments as JSON objects; for `Execute`, the `command` value must be a single shell command string, not an array.\n" + + "\n" + + criticalInstruction("Droid") +} + +func sortedNativeDroidTools() []string { + out := make([]string, 0, len(droidNativeToolNames)) + for name := range droidNativeToolNames { + out = append(out, name) + } + sort.Strings(out) + return out +} + +func joinBacktickList(names []string) string { + if len(names) == 0 { + return "" + } + parts := make([]string, len(names)) + for i, name := range names { + parts[i] = "`" + name + "`" + } + return strings.Join(parts, ", ") +} diff --git a/internal/plugins/features/codexclientcompat/bridge_hermes.go b/internal/plugins/features/codexclientcompat/bridge_hermes.go new file mode 100644 index 00000000..78245254 --- /dev/null +++ b/internal/plugins/features/codexclientcompat/bridge_hermes.go @@ -0,0 +1,61 @@ +package codexclientcompat + +import ( + "encoding/json" + "strings" + + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +// hermesIdentitySentence is the exact upstream Hermes Agent identity sentence. +const hermesIdentitySentence = "You are Hermes Agent, an intelligent AI assistant created by Nous Research." + +var hermesUserAgentMarkers = []string{ + "hermes-agent", + "nousresearch/hermes-agent", + "hermes/", +} + +func hermesAgentMatch(in compatInput) bool { + for _, candidate := range in.agents { + lower := strings.ToLower(candidate) + for _, marker := range hermesUserAgentMarkers { + if strings.Contains(lower, marker) { + return true + } + } + } + return false +} + +func hermesPromptMatch(in compatInput) bool { + return strings.Contains(strings.ToLower(in.prompt), strings.ToLower(hermesIdentitySentence)) +} + +func isHermesBridgeText(text string) bool { + return strings.Contains(text, hermesBridgeMarker) +} + +func buildHermesBridge() string { + return hermesBridgeMarker + ":\n" + + "- Preserve the Hermes Agent identity and system prompt; do not replace or restate it as Codex.\n" + + "- Use structured function/tool calls for every action; never inline textual " + + "`to=functions.` or Harmony-style tool calls in assistant content.\n" + + "- Continue using the available tools until the task is complete and verified.\n" + + "- Perform prerequisite lookup and discovery (files, symbols, context) with tools before acting.\n" + + "- When retrievable context is missing, fetch it with available tools; do not guess or fabricate it.\n" + + "\n" + + "CRITICAL INSTRUCTION:\n" + + "(a) Keep the Hermes identity/system prompt intact; append compatibility guidance, never overwrite it.\n" + + "(b) Never emit textual tool-call syntax (`to=functions.`, Harmony calls) in assistant content; use structured tool calls only." +} + +func applyHermesToolStrict(call *lipapi.Call) { + if call.Extensions == nil { + call.Extensions = map[string]json.RawMessage{} + } + if _, ok := call.Extensions[extCodexToolStrictKey]; ok { + return + } + call.Extensions[extCodexToolStrictKey] = json.RawMessage("false") +} diff --git a/internal/plugins/features/codexclientcompat/bridge_opencode.go b/internal/plugins/features/codexclientcompat/bridge_opencode.go new file mode 100644 index 00000000..8d72bb2e --- /dev/null +++ b/internal/plugins/features/codexclientcompat/bridge_opencode.go @@ -0,0 +1,204 @@ +package codexclientcompat + +import ( + "encoding/json" + "strings" + + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func openCodeAgentMatch(in compatInput) bool { + for _, candidate := range in.agents { + if strings.Contains(strings.ToLower(candidate), "opencode") { + return true + } + } + return false +} + +func openCodePromptMatch(in compatInput) bool { + lower := strings.ToLower(in.prompt) + if strings.Contains(lower, "opencode") { + if strings.Contains(lower, "compatibility") || strings.Contains(lower, "harness") || strings.Contains(lower, "tool") { + return true + } + } + return false +} + +func isOpenCodeHarnessText(text string) bool { + lower := strings.ToLower(text) + return strings.Contains(lower, "opencode") && strings.Contains(lower, "tool") +} + +func applyOpenCodeToolHistoryCompat(call *lipapi.Call) { + convertOrphanedToolResults(call) +} + +func hasStructuredToolTranscript(msgs []lipapi.Message) bool { + for _, m := range msgs { + if m.Role == lipapi.RoleTool { + for _, p := range m.Parts { + if p.Kind == lipapi.PartToolResult { + return true + } + } + } + if m.Role != lipapi.RoleAssistant { + continue + } + for _, p := range m.Parts { + if p.Kind != lipapi.PartJSON { + continue + } + if isFunctionCallPart(p) { + return true + } + } + } + return false +} + +func isFunctionCallPart(p lipapi.Part) bool { + if len(p.Content) == 0 { + return false + } + var fc struct { + Type string `json:"type"` + CallID string `json:"call_id"` + ID string `json:"id"` + Name string `json:"name"` + Function *struct { + Name string `json:"name"` + } `json:"function"` + } + if json.Unmarshal(p.Content, &fc) != nil { + return false + } + if !strings.EqualFold(fc.Type, "function_call") && !strings.EqualFold(fc.Type, "function") { + return false + } + id := firstNonEmpty(fc.CallID, fc.ID) + name := strings.TrimSpace(fc.Name) + if name == "" && fc.Function != nil { + name = strings.TrimSpace(fc.Function.Name) + } + return strings.TrimSpace(id) != "" && name != "" +} + +func convertOrphanedToolResults(call *lipapi.Call) { + known := collectKnownToolCallIDs(call.Messages) + out := make([]lipapi.Message, 0, len(call.Messages)) + for _, m := range call.Messages { + if m.Role != lipapi.RoleTool { + out = append(out, m) + continue + } + kept := make([]lipapi.Part, 0, len(m.Parts)) + for _, p := range m.Parts { + if p.Kind != lipapi.PartToolResult { + kept = append(kept, p) + continue + } + callID := strings.TrimSpace(p.ToolCallID) + if callID != "" { + if _, ok := known[callID]; ok { + kept = append(kept, p) + continue + } + } + out = append(out, convertOrphanedToolResult(p)) + } + if len(kept) > 0 { + out = append(out, lipapi.Message{Role: lipapi.RoleTool, Parts: kept}) + } + } + call.Messages = out +} + +func collectKnownToolCallIDs(msgs []lipapi.Message) map[string]struct{} { + known := make(map[string]struct{}) + for _, m := range msgs { + if m.Role != lipapi.RoleAssistant { + continue + } + for _, p := range m.Parts { + if p.Kind != lipapi.PartJSON { + continue + } + var fc struct { + Type string `json:"type"` + CallID string `json:"call_id"` + ID string `json:"id"` + } + if json.Unmarshal(p.Content, &fc) != nil { + continue + } + // Accept Responses-style ("function_call") and Chat Completions-style + // ("function") assistant tool calls so matching tool results are preserved. + if !strings.EqualFold(fc.Type, "function_call") && !strings.EqualFold(fc.Type, "function") { + continue + } + id := strings.TrimSpace(fc.CallID) + if id == "" { + id = strings.TrimSpace(fc.ID) + } + if id != "" { + known[id] = struct{}{} + } + } + } + return known +} + +func convertOrphanedToolResult(p lipapi.Part) lipapi.Message { + rendered := string(p.Content) + if len(p.Content) == 0 { + rendered = "" + } + header := "Prior tool output (original tool call reference unavailable)." + if id := strings.TrimSpace(p.ToolCallID); id != "" { + header += " call_id=" + id + "." + } + return lipapi.Message{ + Role: lipapi.RoleSystem, + Parts: []lipapi.Part{lipapi.TextPart(header + "\n" + rendered)}, + } +} + +func buildOpenCodeBridge(hasTools bool) string { + var b strings.Builder + b.WriteString(openCodeBridgeMarker) + b.WriteString(":\n") + if hasTools { + // Keep this guidance generic. OpenCode tool names and schemas vary by + // installation, plugin, and session; the structured tool list is the only + // authoritative source of callable names. Duplicating names in prose makes + // random session-specific tools look universal and can bias the model toward + // tools the current request did not actually expose. + b.WriteString("- Prefer the available client shell tool when command execution is needed.\n") + } else { + b.WriteString("- No callable client tools are available in this request. Do not attempt tool calls; respond in plain text or ask the user/client to provide tools.\n") + } + b.WriteString("- Never emit textual tool-call syntax such as `to=functions.` or JSON tool calls in assistant content; use structured tool calls only when tools are available.\n") + if !hasTools { + // No tools are exposed, so do not append criticalInstruction("OpenCode"): + // it tells the model to use agent-provided tools, contradicting the + // "no callable client tools" guidance above and risking spurious tool calls. + return b.String() + } + b.WriteString( + "- For bash-style tools, arguments MUST be a JSON object with string " + + "`command` and string `description`.\n" + + "- Bash-style tools MAY include numeric `timeout` in milliseconds " + + "and string `workdir` when the client schema exposes them.\n" + + "- Never emit array-valued `command` arguments for shell execution.\n" + + "- Do not use `apply_patch`; use the client's native file editing tools instead.\n" + + "- Do not use `update_plan` or `read_plan`; use the client's task tools instead.\n" + + "- If you need a working directory, prefer `workdir` over `cd` commands " + + "or embedding cwd text in `description`.\n" + + "\n" + + criticalInstruction("OpenCode"), + ) + return b.String() +} diff --git a/internal/plugins/features/codexclientcompat/bridge_pi.go b/internal/plugins/features/codexclientcompat/bridge_pi.go new file mode 100644 index 00000000..71aa07d7 --- /dev/null +++ b/internal/plugins/features/codexclientcompat/bridge_pi.go @@ -0,0 +1,64 @@ +package codexclientcompat + +import "strings" + +var ( + piPromptMarkers = []string{ + "operating inside pi", + "coding agent harness", + "available tools:", + "in addition to the tools above", + "guidelines:", + } + piUserAgentMarkers = []string{ + "@mariozechner/pi-coding-agent", + " pi/", + "pi-coding-agent", + } +) + +func piAgentMatch(in compatInput) bool { + for _, candidate := range in.agents { + lower := strings.ToLower(candidate) + for _, marker := range piUserAgentMarkers { + if strings.Contains(lower, marker) { + return true + } + } + } + return false +} + +func piPromptMatch(in compatInput) bool { + lower := strings.ToLower(in.prompt) + hits := 0 + for _, marker := range piPromptMarkers { + if strings.Contains(lower, marker) { + hits++ + } + } + return hits >= 2 +} + +func isPiHarnessText(text string) bool { + lower := strings.ToLower(text) + for _, marker := range piPromptMarkers { + if strings.Contains(lower, marker) { + return true + } + } + return false +} + +func buildPiBridge() string { + return piBridgeMarker + ":\n" + + "- Use only tools exposed by the pi client for this session.\n" + + "- Use `bash` for terminal execution with a JSON object containing string `command` and optional numeric `timeout` in seconds; pi has no default timeout.\n" + + "- Do not emit `shell`, `local_shell_call`, or array-valued shell commands; pi expects the `bash` tool name.\n" + + "- Do not use `apply_patch`; use pi's `edit` tool for exact text replacement in a single file.\n" + + "- For `edit`, pass `path` plus an `edits` array of replacements with `oldText` and `newText`, each matched against the original file.\n" + + "- For file reads use `read` with `path` and optional `offset`/`limit`; for full rewrites use `write` with `path` and `content`.\n" + + "- Keep responses concise and show file paths clearly.\n" + + "\n" + + criticalInstruction("pi") +} diff --git a/internal/plugins/features/codexclientcompat/compat.go b/internal/plugins/features/codexclientcompat/compat.go index 937d1be8..05082ef1 100644 --- a/internal/plugins/features/codexclientcompat/compat.go +++ b/internal/plugins/features/codexclientcompat/compat.go @@ -2,8 +2,6 @@ package codexclientcompat import ( "encoding/json" - "slices" - "sort" "strings" "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" @@ -22,42 +20,9 @@ const ( extCodexToolStrictKey = "openai_codex.tool_strict" extCodexIgnoreUnsupportedGenParamsKey = "openai_codex.ignore_unsupported_gen_params" - // hermesIdentitySentence is the exact upstream Hermes Agent identity sentence. - hermesIdentitySentence = "You are Hermes Agent, an intelligent AI assistant created by Nous Research." - codexDefaultInstruction = "You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer." ) -var ( - piPromptMarkers = []string{ - "operating inside pi", - "coding agent harness", - "available tools:", - "in addition to the tools above", - "guidelines:", - } - piUserAgentMarkers = []string{ - "@mariozechner/pi-coding-agent", - " pi/", - "pi-coding-agent", - } - droidNativeToolNames = map[string]struct{}{ - "Read": {}, "LS": {}, "Execute": {}, "Edit": {}, "Grep": {}, "Glob": {}, - "Create": {}, "TodoWrite": {}, "WebSearch": {}, "FetchUrl": {}, "ExitSpecMode": {}, - } - droidSystemPromptKeywords = []string{ - "you are droid", - "droid, an ai", - "factory droid", - } - droidUserAgentTokens = []string{"factory-cli", "factory_cli", "factorydroid", "droid"} - hermesUserAgentMarkers = []string{ - "hermes-agent", - "nousresearch/hermes-agent", - "hermes/", - } -) - type compatInput struct { agents []string prompt string @@ -159,61 +124,6 @@ func fallbackCompatBridge(call *lipapi.Call) *compatBridge { return nil } -func applyOpenCodeToolHistoryCompat(call *lipapi.Call) { - convertOrphanedToolResults(call) -} - -func hasStructuredToolTranscript(msgs []lipapi.Message) bool { - for _, m := range msgs { - if m.Role == lipapi.RoleTool { - for _, p := range m.Parts { - if p.Kind == lipapi.PartToolResult { - return true - } - } - } - if m.Role != lipapi.RoleAssistant { - continue - } - for _, p := range m.Parts { - if p.Kind != lipapi.PartJSON { - continue - } - if isFunctionCallPart(p) { - return true - } - } - } - return false -} - -func isFunctionCallPart(p lipapi.Part) bool { - if len(p.Content) == 0 { - return false - } - var fc struct { - Type string `json:"type"` - CallID string `json:"call_id"` - ID string `json:"id"` - Name string `json:"name"` - Function *struct { - Name string `json:"name"` - } `json:"function"` - } - if json.Unmarshal(p.Content, &fc) != nil { - return false - } - if !strings.EqualFold(fc.Type, "function_call") && !strings.EqualFold(fc.Type, "function") { - return false - } - id := firstNonEmpty(fc.CallID, fc.ID) - name := strings.TrimSpace(fc.Name) - if name == "" && fc.Function != nil { - name = strings.TrimSpace(fc.Function.Name) - } - return strings.TrimSpace(id) != "" && name != "" -} - func firstNonEmpty(values ...string) string { for _, value := range values { if strings.TrimSpace(value) != "" { @@ -287,88 +197,6 @@ func collectCallToolNames(call *lipapi.Call) []string { return out } -func openCodeAgentMatch(in compatInput) bool { - for _, candidate := range in.agents { - if strings.Contains(strings.ToLower(candidate), "opencode") { - return true - } - } - return false -} - -func openCodePromptMatch(in compatInput) bool { - lower := strings.ToLower(in.prompt) - if strings.Contains(lower, "opencode") { - if strings.Contains(lower, "compatibility") || strings.Contains(lower, "harness") || strings.Contains(lower, "tool") { - return true - } - } - return false -} - -func piAgentMatch(in compatInput) bool { - for _, candidate := range in.agents { - lower := strings.ToLower(candidate) - for _, marker := range piUserAgentMarkers { - if strings.Contains(lower, marker) { - return true - } - } - } - return false -} - -func piPromptMatch(in compatInput) bool { - lower := strings.ToLower(in.prompt) - hits := 0 - for _, marker := range piPromptMarkers { - if strings.Contains(lower, marker) { - hits++ - } - } - return hits >= 2 -} - -func droidAgentMatch(in compatInput) bool { - return slices.ContainsFunc(in.agents, droidUserAgentMatch) -} - -func droidPromptMatch(in compatInput) bool { - lower := strings.ToLower(in.prompt) - for _, keyword := range droidSystemPromptKeywords { - if strings.Contains(lower, keyword) { - return true - } - } - return false -} - -func droidUserAgentMatch(userAgent string) bool { - lower := strings.ToLower(userAgent) - for _, pattern := range droidUserAgentTokens { - if strings.Contains(lower, pattern) { - return true - } - } - return false -} - -func hermesAgentMatch(in compatInput) bool { - for _, candidate := range in.agents { - lower := strings.ToLower(candidate) - for _, marker := range hermesUserAgentMarkers { - if strings.Contains(lower, marker) { - return true - } - } - } - return false -} - -func hermesPromptMatch(in compatInput) bool { - return strings.Contains(strings.ToLower(in.prompt), strings.ToLower(hermesIdentitySentence)) -} - func filterHarnessMessages(msgs []lipapi.Message, isHarness func(string) bool) []lipapi.Message { if len(msgs) == 0 { return msgs @@ -389,30 +217,6 @@ func messageText(m lipapi.Message) string { return b.String() } -func isOpenCodeHarnessText(text string) bool { - lower := strings.ToLower(text) - return strings.Contains(lower, "opencode") && strings.Contains(lower, "tool") -} - -func isPiHarnessText(text string) bool { - lower := strings.ToLower(text) - for _, marker := range piPromptMarkers { - if strings.Contains(lower, marker) { - return true - } - } - return false -} - -func isDroidHarnessText(text string) bool { - lower := strings.ToLower(text) - return strings.Contains(lower, "factory droid") && strings.Contains(lower, "execute") && strings.Contains(lower, "todowrite") -} - -func isHermesBridgeText(text string) bool { - return strings.Contains(text, hermesBridgeMarker) -} - func joinInstructionText(insts []lipapi.Message) string { var b strings.Builder for _, m := range insts { @@ -473,180 +277,6 @@ func appendCompatInstructions(instructions, marker, block string) string { return block } -func convertOrphanedToolResults(call *lipapi.Call) { - known := collectKnownToolCallIDs(call.Messages) - out := make([]lipapi.Message, 0, len(call.Messages)) - for _, m := range call.Messages { - if m.Role != lipapi.RoleTool { - out = append(out, m) - continue - } - kept := make([]lipapi.Part, 0, len(m.Parts)) - for _, p := range m.Parts { - if p.Kind != lipapi.PartToolResult { - kept = append(kept, p) - continue - } - callID := strings.TrimSpace(p.ToolCallID) - if callID != "" { - if _, ok := known[callID]; ok { - kept = append(kept, p) - continue - } - } - out = append(out, convertOrphanedToolResult(p)) - } - if len(kept) > 0 { - out = append(out, lipapi.Message{Role: lipapi.RoleTool, Parts: kept}) - } - } - call.Messages = out -} - -func collectKnownToolCallIDs(msgs []lipapi.Message) map[string]struct{} { - known := make(map[string]struct{}) - for _, m := range msgs { - if m.Role != lipapi.RoleAssistant { - continue - } - for _, p := range m.Parts { - if p.Kind != lipapi.PartJSON { - continue - } - var fc struct { - Type string `json:"type"` - CallID string `json:"call_id"` - ID string `json:"id"` - } - if json.Unmarshal(p.Content, &fc) != nil { - continue - } - // Accept Responses-style ("function_call") and Chat Completions-style - // ("function") assistant tool calls so matching tool results are preserved. - if !strings.EqualFold(fc.Type, "function_call") && !strings.EqualFold(fc.Type, "function") { - continue - } - id := strings.TrimSpace(fc.CallID) - if id == "" { - id = strings.TrimSpace(fc.ID) - } - if id != "" { - known[id] = struct{}{} - } - } - } - return known -} - -func argumentText(raw json.RawMessage) string { - if len(raw) == 0 { - return "" - } - if raw[0] == '"' { - var s string - if json.Unmarshal(raw, &s) == nil { - return s - } - } - return string(raw) -} - -func messagePartText(p lipapi.Part) string { - switch p.Kind { - case lipapi.PartText: - return p.Text - case lipapi.PartToolResult, lipapi.PartJSON: - return string(p.Content) - default: - return string(p.Kind) - } -} - -func convertOrphanedToolResult(p lipapi.Part) lipapi.Message { - rendered := string(p.Content) - if len(p.Content) == 0 { - rendered = "" - } - header := "Prior tool output (original tool call reference unavailable)." - if id := strings.TrimSpace(p.ToolCallID); id != "" { - header += " call_id=" + id + "." - } - return lipapi.Message{ - Role: lipapi.RoleSystem, - Parts: []lipapi.Part{lipapi.TextPart(header + "\n" + rendered)}, - } -} - -func buildOpenCodeBridge(hasTools bool) string { - var b strings.Builder - b.WriteString(openCodeBridgeMarker) - b.WriteString(":\n") - if hasTools { - // Keep this guidance generic. OpenCode tool names and schemas vary by - // installation, plugin, and session; the structured tool list is the only - // authoritative source of callable names. Duplicating names in prose makes - // random session-specific tools look universal and can bias the model toward - // tools the current request did not actually expose. - b.WriteString("- Prefer the available client shell tool when command execution is needed.\n") - } else { - b.WriteString("- No callable client tools are available in this request. Do not attempt tool calls; respond in plain text or ask the user/client to provide tools.\n") - } - b.WriteString("- Never emit textual tool-call syntax such as `to=functions.` or JSON tool calls in assistant content; use structured tool calls only when tools are available.\n") - if !hasTools { - // No tools are exposed, so do not append criticalInstruction("OpenCode"): - // it tells the model to use agent-provided tools, contradicting the - // "no callable client tools" guidance above and risking spurious tool calls. - return b.String() - } - b.WriteString( - "- For bash-style tools, arguments MUST be a JSON object with string " + - "`command` and string `description`.\n" + - "- Bash-style tools MAY include numeric `timeout` in milliseconds " + - "and string `workdir` when the client schema exposes them.\n" + - "- Never emit array-valued `command` arguments for shell execution.\n" + - "- Do not use `apply_patch`; use the client's native file editing tools instead.\n" + - "- Do not use `update_plan` or `read_plan`; use the client's task tools instead.\n" + - "- If you need a working directory, prefer `workdir` over `cd` commands " + - "or embedding cwd text in `description`.\n" + - "\n" + - criticalInstruction("OpenCode"), - ) - return b.String() -} - -func buildPiBridge() string { - return piBridgeMarker + ":\n" + - "- Use only tools exposed by the pi client for this session.\n" + - "- Use `bash` for terminal execution with a JSON object containing string `command` and optional numeric `timeout` in seconds; pi has no default timeout.\n" + - "- Do not emit `shell`, `local_shell_call`, or array-valued shell commands; pi expects the `bash` tool name.\n" + - "- Do not use `apply_patch`; use pi's `edit` tool for exact text replacement in a single file.\n" + - "- For `edit`, pass `path` plus an `edits` array of replacements with `oldText` and `newText`, each matched against the original file.\n" + - "- For file reads use `read` with `path` and optional `offset`/`limit`; for full rewrites use `write` with `path` and `content`.\n" + - "- Keep responses concise and show file paths clearly.\n" + - "\n" + - criticalInstruction("pi") -} - -func buildDroidBridge(availableTools []string) string { - native := sortedNativeDroidTools() - available := availableTools - if len(available) == 0 { - available = native - } - availableText := joinBacktickList(available) - nativeText := joinBacktickList(native) - return droidBridgeMarker + ":\n" + - "- This session is using Factory Droid tools, not Codex-native tools.\n" + - "- Use only tool names that are actually available in this session: " + availableText + ".\n" + - "- Prefer the native Factory Droid tool family when available: " + nativeText + ".\n" + - "- Use Droid argument shapes exactly for the native file/execute tools: `Read(file_path, offset?, limit?)`, `LS(directory_path?)`, `Execute(command, timeout?, cwd?)`, `Edit(file_path, old_str, new_str)`, `Grep(pattern, path?, file_pattern?, max_results?)`, `Glob(pattern, max_results?)`, `Create(file_path, content)`.\n" + - "- Do not emit Codex-native tool names such as `read`, `read_file`, `bash`, `shell`, `apply_patch`, `grep_files`, or `list_dir`.\n" + - "- Use `TodoWrite` instead of Codex task-planner tools, `WebSearch` for web search, and `FetchUrl` for direct URL fetches when those tools are available.\n" + - "- Keep tool arguments as JSON objects; for `Execute`, the `command` value must be a single shell command string, not an array.\n" + - "\n" + - criticalInstruction("Droid") -} - func criticalInstruction(agentName string) string { return "CRITICAL INSTRUCTION:\n" + "(a) NEVER run cat inside a bash command to create a file or append to an " + @@ -655,30 +285,6 @@ func criticalInstruction(agentName string) string { "string matching. Use respective tools provided by the " + agentName + " agent instead." } -func buildHermesBridge() string { - return hermesBridgeMarker + ":\n" + - "- Preserve the Hermes Agent identity and system prompt; do not replace or restate it as Codex.\n" + - "- Use structured function/tool calls for every action; never inline textual " + - "`to=functions.` or Harmony-style tool calls in assistant content.\n" + - "- Continue using the available tools until the task is complete and verified.\n" + - "- Perform prerequisite lookup and discovery (files, symbols, context) with tools before acting.\n" + - "- When retrievable context is missing, fetch it with available tools; do not guess or fabricate it.\n" + - "\n" + - "CRITICAL INSTRUCTION:\n" + - "(a) Keep the Hermes identity/system prompt intact; append compatibility guidance, never overwrite it.\n" + - "(b) Never emit textual tool-call syntax (`to=functions.`, Harmony calls) in assistant content; use structured tool calls only." -} - -func applyHermesToolStrict(call *lipapi.Call) { - if call.Extensions == nil { - call.Extensions = map[string]json.RawMessage{} - } - if _, ok := call.Extensions[extCodexToolStrictKey]; ok { - return - } - call.Extensions[extCodexToolStrictKey] = json.RawMessage("false") -} - func applyIgnoreUnsupportedGenParams(call *lipapi.Call) { if call.Extensions == nil { call.Extensions = map[string]json.RawMessage{} @@ -688,23 +294,3 @@ func applyIgnoreUnsupportedGenParams(call *lipapi.Call) { } call.Extensions[extCodexIgnoreUnsupportedGenParamsKey] = json.RawMessage("true") } - -func sortedNativeDroidTools() []string { - out := make([]string, 0, len(droidNativeToolNames)) - for name := range droidNativeToolNames { - out = append(out, name) - } - sort.Strings(out) - return out -} - -func joinBacktickList(names []string) string { - if len(names) == 0 { - return "" - } - parts := make([]string, len(names)) - for i, name := range names { - parts[i] = "`" + name + "`" - } - return strings.Join(parts, ", ") -} diff --git a/internal/plugins/frontends/anthropic/handler.go b/internal/plugins/frontends/anthropic/handler.go index 60e3a4f8..63c47352 100644 --- a/internal/plugins/frontends/anthropic/handler.go +++ b/internal/plugins/frontends/anthropic/handler.go @@ -127,10 +127,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }) releaseDecode() if err != nil { - log := h.Log - if log == nil { - log = slog.Default() - } + log := diag.LoggerOrDefault(h.Log) diag.LogError(ctx, log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) streamdebug.LogDecodeFailure(ctx, log, ID, body, err) h.logWriteJSONErr(ctx, "write error json failed", WriteErrorJSON(w, http.StatusBadRequest, "invalid request JSON", "invalid_request_error")) diff --git a/internal/plugins/frontends/gemini/handler.go b/internal/plugins/frontends/gemini/handler.go index 70745574..b83df7ce 100644 --- a/internal/plugins/frontends/gemini/handler.go +++ b/internal/plugins/frontends/gemini/handler.go @@ -125,10 +125,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { }) releaseDecode() if err != nil { - log := h.Log - if log == nil { - log = slog.Default() - } + log := diag.LoggerOrDefault(h.Log) diag.LogError(ctx, log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) streamdebug.LogDecodeFailure(ctx, log, ID, body, err) h.logWriteJSONErr(ctx, "write error json failed", WriteErrorJSON(w, http.StatusBadRequest, "invalid request JSON")) diff --git a/internal/plugins/frontends/openailegacy/handler.go b/internal/plugins/frontends/openailegacy/handler.go index a4d0f762..acf4d212 100644 --- a/internal/plugins/frontends/openailegacy/handler.go +++ b/internal/plugins/frontends/openailegacy/handler.go @@ -119,10 +119,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { decoded, err := DecodeChatRequest(body, DecodeOptions{RouteSelector: sel, Headers: r.Header}) releaseDecode() if err != nil { - log := h.Log - if log == nil { - log = slog.Default() - } + log := diag.LoggerOrDefault(h.Log) diag.LogError(ctx, log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) streamdebug.LogDecodeFailure(ctx, log, ID, body, err) h.logWriteJSONErr(ctx, "write error json failed", WriteErrorJSON(w, http.StatusBadRequest, "invalid request JSON", "invalid_request_error", "")) diff --git a/internal/plugins/frontends/openairesponses/handler.go b/internal/plugins/frontends/openairesponses/handler.go index f020c734..ac2f9b97 100644 --- a/internal/plugins/frontends/openairesponses/handler.go +++ b/internal/plugins/frontends/openairesponses/handler.go @@ -163,10 +163,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { decoded, err := DecodeCreateRequest(body, DecodeOptions{RouteSelector: sel, Headers: r.Header}) releaseDecode() if err != nil { - log := h.Log - if log == nil { - log = slog.Default() - } + log := diag.LoggerOrDefault(h.Log) diag.LogError(ctx, log, "decode request failed", diag.AttrOpts{}, err, slog.String("detail", diag.TruncErrDetail(err, 512))) streamdebug.LogDecodeFailure(ctx, log, ID, body, err) h.logWriteJSONErr(ctx, "write error json failed", WriteErrorJSON( From 91d7eb54f24c3f8f7879d1c707b801f6a68cb113 Mon Sep 17 00:00:00 2001 From: Mateusz Date: Wed, 1 Jul 2026 15:49:04 +0200 Subject: [PATCH 2/2] address CodeRabbit review findings on codex ws/compat - ws_session: guard closeIdle against stale idle-timer callbacks that race with acquire reusing a tracked session (skip only when the session is still the map entry and its timer was stopped since firing); orphaned sessions are still closed to avoid leaking connections. - ws_stream: releaseOnce(true) on mapper errors so the session is evicted even if the caller never calls Close, matching the read-error path. - codexclientcompat: extract parseFunctionCallID shared by isFunctionCallPart and collectKnownToolCallIDs so accepted type variants stay in sync. - codexclientcompat: preserve part order in convertOrphanedToolResults so converted orphaned results keep their original position relative to known results instead of being flushed after them. Adds focused tests for each change. Co-authored-by: Cursor --- .../backends/openaicodex/ws_internal_test.go | 90 ++++++++ .../backends/openaicodex/ws_session.go | 10 + .../plugins/backends/openaicodex/ws_stream.go | 4 + .../codexclientcompat/bridge_opencode.go | 56 ++--- .../bridge_opencode_internal_test.go | 201 ++++++++++++++++++ 5 files changed, 337 insertions(+), 24 deletions(-) create mode 100644 internal/plugins/features/codexclientcompat/bridge_opencode_internal_test.go diff --git a/internal/plugins/backends/openaicodex/ws_internal_test.go b/internal/plugins/backends/openaicodex/ws_internal_test.go index 5ad59f56..4cf0f97c 100644 --- a/internal/plugins/backends/openaicodex/ws_internal_test.go +++ b/internal/plugins/backends/openaicodex/ws_internal_test.go @@ -364,3 +364,93 @@ func newTestWebSocketPair(t *testing.T) (*gorillawebsocket.Conn, *gorillawebsock return nil, nil } } + +func TestWSSessionStoreCloseIdleClosesTrackedArmedSession(t *testing.T) { + t.Parallel() + clientConn, serverConn := newTestWebSocketPair(t) + defer func() { _ = serverConn.Close() }() + + store := newWSSessionStore() + key := wsSessionKey{baseURL: "ws://example.test", accessToken: "tok", conversation: "sess"} + timer := time.AfterFunc(time.Hour, func() {}) + defer timer.Stop() + session := &wsSessionConn{ + key: key, + store: store, + sem: make(chan struct{}, 1), + conn: clientConn, + idleTimer: timer, + } + store.sessions[key] = session + + store.closeIdle(key, session) + + if session.conn != nil { + t.Fatal("tracked armed idle session conn was not closed") + } + if _, ok := store.sessions[key]; ok { + t.Fatal("tracked armed idle session was not forgotten") + } + if session.idleTimer != nil { + t.Fatal("tracked armed idle timer was not stopped") + } +} + +func TestWSSessionStoreCloseIdleSkipsTrackedInUseSession(t *testing.T) { + t.Parallel() + clientConn, serverConn := newTestWebSocketPair(t) + defer func() { _ = serverConn.Close() }() + + store := newWSSessionStore() + key := wsSessionKey{baseURL: "ws://example.test", accessToken: "tok", conversation: "sess"} + // idleTimer == nil models a concurrent acquire having stopped the timer to + // reuse this tracked session while the stale closeIdle callback races. The + // session is still the map entry, so the guard must skip closing it. + session := &wsSessionConn{ + key: key, + store: store, + sem: make(chan struct{}, 1), + conn: clientConn, + idleTimer: nil, + } + store.sessions[key] = session + + store.closeIdle(key, session) + + if session.conn == nil { + t.Fatal("tracked in-use session conn was closed by stale closeIdle callback") + } + if _, ok := store.sessions[key]; !ok { + t.Fatal("tracked in-use session was forgotten by stale closeIdle callback") + } +} + +func TestWSStreamReleasesSessionOnMapperError(t *testing.T) { + t.Parallel() + clientConn, serverConn := newTestWebSocketPair(t) + defer func() { _ = serverConn.Close() }() + + // A non-JSON text frame causes codexEventMapper.handleData to return a + // malformed-stream-event error on the mapper-error path. + if err := serverConn.WriteMessage(gorillawebsocket.TextMessage, []byte("not-json")); err != nil { + t.Fatal(err) + } + + var released bool + var closeConn bool + stream := newWSStream(clientConn, 0) + stream.release = func(close bool) { + released = true + closeConn = close + } + + if _, err := stream.Recv(context.Background()); err == nil { + t.Fatal("expected Recv error for malformed frame") + } + if !released { + t.Fatal("expected session released on mapper error") + } + if !closeConn { + t.Fatal("expected release with closeConn=true on mapper error") + } +} diff --git a/internal/plugins/backends/openaicodex/ws_session.go b/internal/plugins/backends/openaicodex/ws_session.go index 99a626a1..fad0036e 100644 --- a/internal/plugins/backends/openaicodex/ws_session.go +++ b/internal/plugins/backends/openaicodex/ws_session.go @@ -128,6 +128,16 @@ func (s *wsSessionStore) closeIdle(key wsSessionKey, session *wsSessionConn) { defer session.unlock() s.mu.Lock() defer s.mu.Unlock() + // The idle timer may have fired after acquire already stopped it (Stop returns + // false for an already-expired timer) and is about to reuse this tracked + // session. Closing in that case would evict a session an in-flight acquire is + // reusing and force a redundant dial on an untracked session. Skip only when + // the session is still the tracked entry AND its timer was stopped since + // firing. A session that is no longer the map entry (replaced/forgotten) still + // needs its conn closed to avoid leaking an orphaned connection. + if s.sessions[key] == session && session.idleTimer == nil { + return + } session.closeConnLocked() session.stopIdleTimerLocked() s.forgetLocked(key, session) diff --git a/internal/plugins/backends/openaicodex/ws_stream.go b/internal/plugins/backends/openaicodex/ws_stream.go index f34006b4..ebfeb473 100644 --- a/internal/plugins/backends/openaicodex/ws_stream.go +++ b/internal/plugins/backends/openaicodex/ws_stream.go @@ -92,6 +92,10 @@ func (s *wsStream) Recv(ctx context.Context) (lipapi.Event, error) { } if err := s.mapper.handleData(text); err != nil { s.mu.Unlock() + // A malformed frame desyncs the stream; release with a close so the + // session is evicted even if the caller never calls Close, mirroring the + // read-error path. releaseOnce is idempotent, so a later Close is safe. + s.releaseOnce(true) return lipapi.Event{}, err } s.mu.Unlock() diff --git a/internal/plugins/features/codexclientcompat/bridge_opencode.go b/internal/plugins/features/codexclientcompat/bridge_opencode.go index 8d72bb2e..4ef1c2d0 100644 --- a/internal/plugins/features/codexclientcompat/bridge_opencode.go +++ b/internal/plugins/features/codexclientcompat/bridge_opencode.go @@ -59,9 +59,14 @@ func hasStructuredToolTranscript(msgs []lipapi.Message) bool { return false } -func isFunctionCallPart(p lipapi.Part) bool { +// parseFunctionCallID extracts the call id and tool name from an assistant +// function-call part. It accepts both Responses-style ("function_call") and +// Chat Completions-style ("function") type tags so the accepted variants stay +// in sync across callers. ok is true when the part is a recognized function-call +// part; id and name are trimmed and may be empty. +func parseFunctionCallID(p lipapi.Part) (id, name string, ok bool) { if len(p.Content) == 0 { - return false + return "", "", false } var fc struct { Type string `json:"type"` @@ -73,17 +78,22 @@ func isFunctionCallPart(p lipapi.Part) bool { } `json:"function"` } if json.Unmarshal(p.Content, &fc) != nil { - return false + return "", "", false } if !strings.EqualFold(fc.Type, "function_call") && !strings.EqualFold(fc.Type, "function") { - return false + return "", "", false } - id := firstNonEmpty(fc.CallID, fc.ID) - name := strings.TrimSpace(fc.Name) + id = strings.TrimSpace(firstNonEmpty(fc.CallID, fc.ID)) + name = strings.TrimSpace(fc.Name) if name == "" && fc.Function != nil { name = strings.TrimSpace(fc.Function.Name) } - return strings.TrimSpace(id) != "" && name != "" + return id, name, true +} + +func isFunctionCallPart(p lipapi.Part) bool { + id, name, ok := parseFunctionCallID(p) + return ok && id != "" && name != "" } func convertOrphanedToolResults(call *lipapi.Call) { @@ -94,7 +104,17 @@ func convertOrphanedToolResults(call *lipapi.Call) { out = append(out, m) continue } - kept := make([]lipapi.Part, 0, len(m.Parts)) + // Flush kept tool parts in place so converted orphaned results are emitted + // in the same relative order they appear in the original message. Buffering + // kept parts until the loop ends would move orphaned (now System) results + // ahead of earlier known results and misrepresent tool/result ordering. + var kept []lipapi.Part + flushKept := func() { + if len(kept) > 0 { + out = append(out, lipapi.Message{Role: lipapi.RoleTool, Parts: kept}) + kept = nil + } + } for _, p := range m.Parts { if p.Kind != lipapi.PartToolResult { kept = append(kept, p) @@ -107,11 +127,10 @@ func convertOrphanedToolResults(call *lipapi.Call) { continue } } + flushKept() out = append(out, convertOrphanedToolResult(p)) } - if len(kept) > 0 { - out = append(out, lipapi.Message{Role: lipapi.RoleTool, Parts: kept}) - } + flushKept() } call.Messages = out } @@ -126,23 +145,12 @@ func collectKnownToolCallIDs(msgs []lipapi.Message) map[string]struct{} { if p.Kind != lipapi.PartJSON { continue } - var fc struct { - Type string `json:"type"` - CallID string `json:"call_id"` - ID string `json:"id"` - } - if json.Unmarshal(p.Content, &fc) != nil { - continue - } // Accept Responses-style ("function_call") and Chat Completions-style // ("function") assistant tool calls so matching tool results are preserved. - if !strings.EqualFold(fc.Type, "function_call") && !strings.EqualFold(fc.Type, "function") { + id, _, ok := parseFunctionCallID(p) + if !ok { continue } - id := strings.TrimSpace(fc.CallID) - if id == "" { - id = strings.TrimSpace(fc.ID) - } if id != "" { known[id] = struct{}{} } diff --git a/internal/plugins/features/codexclientcompat/bridge_opencode_internal_test.go b/internal/plugins/features/codexclientcompat/bridge_opencode_internal_test.go new file mode 100644 index 00000000..8ef934f8 --- /dev/null +++ b/internal/plugins/features/codexclientcompat/bridge_opencode_internal_test.go @@ -0,0 +1,201 @@ +package codexclientcompat + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/matdev83/go-llm-interactive-proxy/pkg/lipapi" +) + +func TestParseFunctionCallID(t *testing.T) { + t.Parallel() + cases := []struct { + name string + part lipapi.Part + wantID string + wantName string + wantOK bool + }{ + { + name: "responses style with call_id and name", + part: lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":"call_1","name":"toolA"}`)}, + wantID: "call_1", + wantName: "toolA", + wantOK: true, + }, + { + name: "chat completions style with id and function.name", + part: lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function","id":"call_2","function":{"name":"toolB"}}`)}, + wantID: "call_2", + wantName: "toolB", + wantOK: true, + }, + { + name: "responses style missing name", + part: lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":"call_3"}`)}, + wantID: "call_3", + wantName: "", + wantOK: true, + }, + { + name: "whitespace call id is trimmed", + part: lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":" call_4 ","name":"toolD"}`)}, + wantID: "call_4", + wantName: "toolD", + wantOK: true, + }, + { + name: "non-function type rejected", + part: lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"text"}`)}, + wantOK: false, + }, + { + name: "empty content rejected", + part: lipapi.Part{Kind: lipapi.PartJSON}, + wantOK: false, + }, + { + name: "invalid json rejected", + part: lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`not-json`)}, + wantOK: false, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + id, name, ok := parseFunctionCallID(tc.part) + if ok != tc.wantOK { + t.Fatalf("ok = %v, want %v (id=%q name=%q)", ok, tc.wantOK, id, name) + } + if id != tc.wantID { + t.Fatalf("id = %q, want %q", id, tc.wantID) + } + if name != tc.wantName { + t.Fatalf("name = %q, want %q", name, tc.wantName) + } + }) + } +} + +func TestIsFunctionCallPartRequiresIDAndName(t *testing.T) { + t.Parallel() + cases := []struct { + name string + part lipapi.Part + want bool + }{ + {"both present", lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":"call_1","name":"toolA"}`)}, true}, + {"chat style both present", lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function","id":"call_2","function":{"name":"toolB"}}`)}, true}, + {"missing name", lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":"call_3"}`)}, false}, + {"missing id", lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","name":"toolA"}`)}, false}, + {"non-function type", lipapi.Part{Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"text"}`)}, false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := isFunctionCallPart(tc.part); got != tc.want { + t.Fatalf("isFunctionCallPart = %v, want %v", got, tc.want) + } + }) + } +} + +func TestCollectKnownToolCallIDsAcceptsBothStyles(t *testing.T) { + t.Parallel() + msgs := []lipapi.Message{ + { + Role: lipapi.RoleAssistant, + Parts: []lipapi.Part{ + {Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":"resp_id","name":"toolA"}`)}, + {Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function","id":"chat_id","function":{"name":"toolB"}}`)}, + {Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","name":"noID"}`)}, + {Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"text","text":"hi"}`)}, + }, + }, + } + known := collectKnownToolCallIDs(msgs) + if _, ok := known["resp_id"]; !ok { + t.Errorf("missing responses-style id %q: %v", "resp_id", known) + } + if _, ok := known["chat_id"]; !ok { + t.Errorf("missing chat-completions-style id %q: %v", "chat_id", known) + } + if len(known) != 2 { + t.Errorf("known = %v, want exactly 2 entries", known) + } +} + +func TestConvertOrphanedToolResultsPreservesPartOrder(t *testing.T) { + t.Parallel() + call := &lipapi.Call{ + Messages: []lipapi.Message{ + { + Role: lipapi.RoleAssistant, + Parts: []lipapi.Part{ + {Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":"a","name":"toolA"}`)}, + {Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":"c","name":"toolC"}`)}, + }, + }, + { + Role: lipapi.RoleTool, + Parts: []lipapi.Part{ + {Kind: lipapi.PartToolResult, ToolCallID: "a", Content: json.RawMessage(`{"out":"A"}`)}, + {Kind: lipapi.PartToolResult, ToolCallID: "orphan", Content: json.RawMessage(`{"out":"B"}`)}, + {Kind: lipapi.PartToolResult, ToolCallID: "c", Content: json.RawMessage(`{"out":"C"}`)}, + }, + }, + }, + } + + convertOrphanedToolResults(call) + + // Expected order preserves the original A -> B -> C sequence: the known + // results bracket the converted orphan instead of all known results being + // flushed after it. + wantRoles := []lipapi.Role{lipapi.RoleAssistant, lipapi.RoleTool, lipapi.RoleSystem, lipapi.RoleTool} + if len(call.Messages) != len(wantRoles) { + t.Fatalf("messages = %d, want %d: %#v", len(call.Messages), len(wantRoles), call.Messages) + } + for i, want := range wantRoles { + if call.Messages[i].Role != want { + t.Fatalf("messages[%d].Role = %q, want %q", i, call.Messages[i].Role, want) + } + } + if got := call.Messages[1].Parts[0].ToolCallID; got != "a" { + t.Fatalf("messages[1].Parts[0].ToolCallID = %q, want %q", got, "a") + } + if !strings.Contains(messageText(call.Messages[2]), "Prior tool output") { + t.Fatalf("messages[2] = %#v, want System orphan conversion", call.Messages[2]) + } + if got := call.Messages[3].Parts[0].ToolCallID; got != "c" { + t.Fatalf("messages[3].Parts[0].ToolCallID = %q, want %q", got, "c") + } +} + +func TestConvertOrphanedToolResultsLeavesKnownOnlyMessagesUntouched(t *testing.T) { + t.Parallel() + call := &lipapi.Call{ + Messages: []lipapi.Message{ + { + Role: lipapi.RoleAssistant, + Parts: []lipapi.Part{ + {Kind: lipapi.PartJSON, Content: json.RawMessage(`{"type":"function_call","call_id":"a","name":"toolA"}`)}, + }, + }, + { + Role: lipapi.RoleTool, + Parts: []lipapi.Part{ + {Kind: lipapi.PartToolResult, ToolCallID: "a", Content: json.RawMessage(`{"out":"A"}`)}, + }, + }, + }, + } + convertOrphanedToolResults(call) + if len(call.Messages) != 2 { + t.Fatalf("messages = %d, want 2: %#v", len(call.Messages), call.Messages) + } + if call.Messages[1].Role != lipapi.RoleTool || len(call.Messages[1].Parts) != 1 { + t.Fatalf("messages[1] = %#v, want single Tool([A])", call.Messages[1]) + } +}