diff --git a/Cargo.lock b/Cargo.lock index fcb0d8721..c936b824e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2630,6 +2630,7 @@ dependencies = [ "globset", "ignore", "lopdf", + "objc2-app-kit 0.3.2", "portable-pty", "prost", "quick-xml 0.38.4", diff --git a/crates/agent-gateway/internal/handler/helpers.go b/crates/agent-gateway/internal/handler/helpers.go index 0221c6797..41edec9f4 100644 --- a/crates/agent-gateway/internal/handler/helpers.go +++ b/crates/agent-gateway/internal/handler/helpers.go @@ -53,12 +53,12 @@ func waitForEnvelope( } } -func gatewayErrorStatus(errResp *gatewayv1.ErrorResponse) int { +func GatewayErrorStatus(errResp *gatewayv1.ErrorResponse) int { if errResp == nil { return http.StatusBadGateway } - switch errResp.GetCode() { - case http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound, http.StatusConflict: + switch int(errResp.GetCode()) { + case http.StatusBadRequest, http.StatusUnauthorized, http.StatusForbidden, http.StatusNotFound, http.StatusConflict: return int(errResp.GetCode()) default: return http.StatusBadGateway diff --git a/crates/agent-gateway/internal/handler/helpers_test.go b/crates/agent-gateway/internal/handler/helpers_test.go new file mode 100644 index 000000000..49d07b452 --- /dev/null +++ b/crates/agent-gateway/internal/handler/helpers_test.go @@ -0,0 +1,29 @@ +package handler + +import ( + "net/http" + "testing" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func TestGatewayErrorStatusPassesExpectedClientErrors(t *testing.T) { + t.Parallel() + + cases := map[int32]int{ + http.StatusBadRequest: http.StatusBadRequest, + http.StatusUnauthorized: http.StatusUnauthorized, + http.StatusForbidden: http.StatusForbidden, + http.StatusNotFound: http.StatusNotFound, + http.StatusConflict: http.StatusConflict, + http.StatusTeapot: http.StatusBadGateway, + 0: http.StatusBadGateway, + } + + for code, want := range cases { + got := GatewayErrorStatus(&gatewayv1.ErrorResponse{Code: code}) + if got != want { + t.Fatalf("GatewayErrorStatus(%d) = %d, want %d", code, got, want) + } + } +} diff --git a/crates/agent-gateway/internal/handler/upload.go b/crates/agent-gateway/internal/handler/upload.go index c286bad03..a52ab8e5e 100644 --- a/crates/agent-gateway/internal/handler/upload.go +++ b/crates/agent-gateway/internal/handler/upload.go @@ -107,7 +107,7 @@ func ImportReadableFiles( return } if errResp := env.GetError(); errResp != nil { - writeError(w, gatewayErrorStatus(errResp), errResp.GetMessage()) + writeError(w, GatewayErrorStatus(errResp), errResp.GetMessage()) return } diff --git a/crates/agent-gateway/internal/server/http.go b/crates/agent-gateway/internal/server/http.go index 26075524b..89e96eebb 100644 --- a/crates/agent-gateway/internal/server/http.go +++ b/crates/agent-gateway/internal/server/http.go @@ -106,11 +106,7 @@ func publicHistoryShare(cfg *config.Config, sm *session.Manager) http.HandlerFun return } if errResp := response.GetError(); errResp != nil { - status := http.StatusInternalServerError - if isPublicHistoryShareNotFound(errResp.GetMessage()) { - status = http.StatusNotFound - } - writePublicHistoryShareError(w, status, errResp.GetMessage()) + writePublicHistoryShareError(w, handler.GatewayErrorStatus(errResp), errResp.GetMessage()) return } @@ -130,13 +126,6 @@ func publicHistoryShare(cfg *config.Config, sm *session.Manager) http.HandlerFun } } -func isPublicHistoryShareNotFound(message string) bool { - normalized := strings.TrimSpace(message) - return strings.Contains(normalized, "分享链接不存在或已关闭") || - strings.Contains(normalized, "分享 token 不能为空") || - strings.Contains(normalized, "未找到对应的历史对话") -} - func writePublicHistoryShareError(w http.ResponseWriter, status int, message string) { writeJSON(w, status, map[string]any{ "error": strings.TrimSpace(message), diff --git a/crates/agent-gateway/internal/server/http_test.go b/crates/agent-gateway/internal/server/http_test.go index 2098187e1..2f780b98d 100644 --- a/crates/agent-gateway/internal/server/http_test.go +++ b/crates/agent-gateway/internal/server/http_test.go @@ -122,6 +122,29 @@ func TestPublicHistoryShareResolvesWithoutAuthorization(t *testing.T) { } func TestPublicHistoryShareReturnsNotFoundForDisabledToken(t *testing.T) { + status := publicHistoryShareErrorStatusForTest(t, http.StatusNotFound, "分享链接不存在或已关闭") + if status != http.StatusNotFound { + t.Fatalf("expected status %d, got %d", http.StatusNotFound, status) + } +} + +func TestPublicHistoryShareReturnsBadRequestFromAgentCode(t *testing.T) { + status := publicHistoryShareErrorStatusForTest(t, http.StatusBadRequest, "分享 token 不能为空") + if status != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, status) + } +} + +func TestPublicHistoryShareDoesNotInferStatusFromLegacyMessage(t *testing.T) { + status := publicHistoryShareErrorStatusForTest(t, http.StatusInternalServerError, "分享链接不存在或已关闭") + if status != http.StatusBadGateway { + t.Fatalf("expected status %d, got %d", http.StatusBadGateway, status) + } +} + +func publicHistoryShareErrorStatusForTest(t *testing.T, code int, message string) int { + t.Helper() + sm := session.NewManager() sm.RecordAuthentication("desktop-agent", "0.9.0", "session-1") agentSession := session.NewAgentSession(sm.LatestAuthSnapshot()) @@ -151,8 +174,8 @@ func TestPublicHistoryShareReturnsNotFoundForDisabledToken(t *testing.T) { Timestamp: time.Now().Unix(), Payload: &gatewayv1.AgentEnvelope_Error{ Error: &gatewayv1.ErrorResponse{ - Code: http.StatusNotFound, - Message: "分享链接不存在或已关闭", + Code: int32(code), + Message: message, }, }, }) @@ -162,9 +185,7 @@ func TestPublicHistoryShareReturnsNotFoundForDisabledToken(t *testing.T) { case <-time.After(time.Second): t.Fatal("timed out waiting for public share response") } - if rec.Code != http.StatusNotFound { - t.Fatalf("expected status %d, got %d body %s", http.StatusNotFound, rec.Code, rec.Body.String()) - } + return rec.Code } func TestPublicHistoryShareReturnsUnavailableWhenAgentOffline(t *testing.T) { diff --git a/crates/agent-gateway/internal/server/websocket.go b/crates/agent-gateway/internal/server/websocket.go index 1d2466be8..56f87c440 100644 --- a/crates/agent-gateway/internal/server/websocket.go +++ b/crates/agent-gateway/internal/server/websocket.go @@ -13,7 +13,6 @@ import ( "golang.org/x/net/websocket" "github.com/liveagent/agent-gateway/internal/config" - "github.com/liveagent/agent-gateway/internal/handler" gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" "github.com/liveagent/agent-gateway/internal/session" ) @@ -79,7 +78,7 @@ type websocketConnection struct { conn *websocket.Conn - writeMu sync.Mutex + writer *websocketConnectionWriter closeOnce sync.Once done chan struct{} authorized bool @@ -94,16 +93,8 @@ type websocketConnection struct { chatEventsCleanup func() heartbeatOnce sync.Once - activeChatsMu sync.RWMutex - activeChats map[string]*websocketChatState - recentChats map[string]time.Time - - activeChatAttachmentsMu sync.Mutex - activeChatAttachments map[string]context.CancelFunc - - terminalInterestMu sync.RWMutex - terminalProjectSubscriptions map[string]struct{} - terminalSessionSubscriptions map[string]struct{} + chatTracker *websocketChatTracker + terminalInterest *websocketTerminalInterestTracker } const recentActiveChatRetention = 5 * time.Second @@ -118,15 +109,13 @@ func NewWebSocketServer(cfg *config.Config, sm *session.Manager) http.Handler { }, Handler: websocket.Handler(func(conn *websocket.Conn) { state := &websocketConnection{ - cfg: cfg, - sm: sm, - conn: conn, - done: make(chan struct{}), - activeChats: make(map[string]*websocketChatState), - recentChats: make(map[string]time.Time), - activeChatAttachments: make(map[string]context.CancelFunc), - terminalProjectSubscriptions: make(map[string]struct{}), - terminalSessionSubscriptions: make(map[string]struct{}), + cfg: cfg, + sm: sm, + conn: conn, + writer: newWebsocketConnectionWriter(conn, cfg.WebSocketWriteTimeout), + done: make(chan struct{}), + chatTracker: newWebsocketChatTracker(), + terminalInterest: newWebsocketTerminalInterestTracker(), } defer state.close() state.serve() @@ -371,61 +360,19 @@ func (c *websocketConnection) replayTerminalSessionSnapshot() { } func (c *websocketConnection) rememberTerminalProject(projectPathKey string) { - projectPathKey = strings.TrimSpace(projectPathKey) - if projectPathKey == "" { - return - } - c.terminalInterestMu.Lock() - c.terminalProjectSubscriptions[projectPathKey] = struct{}{} - c.terminalInterestMu.Unlock() + c.terminalInterest.rememberProject(projectPathKey) } func (c *websocketConnection) rememberTerminalSession(sessionID string, projectPathKey string) { - sessionID = strings.TrimSpace(sessionID) - projectPathKey = strings.TrimSpace(projectPathKey) - if sessionID == "" && projectPathKey == "" { - return - } - c.terminalInterestMu.Lock() - if sessionID != "" { - c.terminalSessionSubscriptions[sessionID] = struct{}{} - } - if projectPathKey != "" { - c.terminalProjectSubscriptions[projectPathKey] = struct{}{} - } - c.terminalInterestMu.Unlock() + c.terminalInterest.rememberSession(sessionID, projectPathKey) } func (c *websocketConnection) forgetTerminalInterest(sessionID string, projectPathKey string) { - sessionID = strings.TrimSpace(sessionID) - projectPathKey = strings.TrimSpace(projectPathKey) - c.terminalInterestMu.Lock() - if sessionID != "" { - delete(c.terminalSessionSubscriptions, sessionID) - } - if sessionID == "" && projectPathKey != "" { - delete(c.terminalProjectSubscriptions, projectPathKey) - } - c.terminalInterestMu.Unlock() + c.terminalInterest.forget(sessionID, projectPathKey) } func (c *websocketConnection) shouldForwardTerminalEvent(event *gatewayv1.TerminalEvent) bool { - if event == nil { - return false - } - sessionID := strings.TrimSpace(event.GetSessionId()) - projectPathKey := strings.TrimSpace(event.GetProjectPathKey()) - kind := strings.TrimSpace(event.GetKind()) - - if kind != "output" { - return sessionID != "" || projectPathKey != "" - } - - c.terminalInterestMu.RLock() - _, sessionSubscribed := c.terminalSessionSubscriptions[sessionID] - c.terminalInterestMu.RUnlock() - - return sessionID != "" && sessionSubscribed + return c.terminalInterest.shouldForward(event) } func (c *websocketConnection) startWebSocketHeartbeat() { @@ -458,2791 +405,84 @@ func (c *websocketConnection) startWebSocketHeartbeat() { } func (c *websocketConnection) dispatch(req websocketRequest) { - switch req.Type { - case "status.get": - _ = c.writeResponse(req.ID, c.sm.Status()) - case "fs.roots": - c.handleFsRoots(req) - case "fs.list_dirs": - c.handleFsListDirs(req) - case "fs.create_project_folder": - c.handleFsCreateProjectFolder(req) - case "fs.list": - c.handleFsList(req) - case "fs.write_text": - c.handleFsWriteText(req) - case "fs.create_dir": - c.handleFsCreateDir(req) - case "fs.rename": - c.handleFsRename(req) - case "fs.delete": - c.handleFsDelete(req) - case "history.list": - c.handleHistoryList(req) - case "history.workdirs": - c.handleHistoryWorkdirs(req) - case "history.shared_list": - c.handleHistorySharedList(req) - case "history.get": - c.handleHistoryGet(req) - case "history.rename": - c.handleHistoryRename(req) - case "history.pin": - c.handleHistoryPin(req) - case "history.share.get": - c.handleHistoryShareGet(req) - case "history.share.set": - c.handleHistoryShareSet(req) - case "history.delete": - c.handleHistoryDelete(req) - case "history.truncate": - c.handleHistoryTruncate(req) - case "providers.list": - c.handleProviderList(req) - case "settings.get": - c.handleSettingsGet(req) - case "settings.update": - c.handleSettingsUpdate(req) - case "skills.list": - c.handleSkillFilesList(req) - case "mentions.list": - c.handleFileMentionList(req) - case "skills.read-metadata": - c.handleSkillMetadataRead(req) - case "skills.read-text": - c.handleSkillTextRead(req) - case "skills.manage": - c.handleSkillManage(req) - case "chat.start": - c.handleChatStart(req) - case "chat.resume": - c.handleChatResume(req) - case "chat.attach": - c.handleChatAttach(req) - case "chat.detach": - c.handleChatDetach(req) - case "chat.cancel": - c.handleChatCancel(req) - case "files.preview": - c.handleUploadedImagePreview(req) - case "memory.manage": - c.handleMemoryManage(req) - case "terminal.shell_options", "terminal.list", "terminal.create", "terminal.attach", "terminal.input", "terminal.resize", "terminal.rename", "terminal.close", "terminal.close_project": - c.handleTerminalRequest(req) - case "terminal.detach": - c.handleTerminalDetach(req) - case "git.status", "git.branches", "git.init", "git.switch_branch", "git.create_branch", "git.diff", "git.log", "git.commit_details", "git.compare_commit_with_remote", "git.commit_diff", "git.stage", "git.stage_all", "git.unstage", "git.unstage_all", "git.discard", "git.discard_all", "git.add_to_gitignore", "git.commit", "git.fetch", "git.pull", "git.set_remote", "git.push": - c.handleGitRequest(req) - case "cron.manage": - c.handleCronManage(req) - case "provider.models": - c.handleProviderModels(req) - default: + handler := websocketRequestHandlers[req.Type] + if handler == nil { _ = c.writeError(req.ID, "unsupported request type") - } -} - -func (c *websocketConnection) handleFsRoots(req websocketRequest) { - // Payload is intentionally empty; we still decode to reject unexpected fields. - var body struct{} - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid fs.roots payload") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FsRoots{ - FsRoots: &gatewayv1.FsRootsRequest{}, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFsRootsResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") return } - - rootPayload := make([]map[string]any, 0, len(resp.GetRoots())) - for _, root := range resp.GetRoots() { - rootPayload = append(rootPayload, map[string]any{ - "id": root.GetId(), - "path": root.GetPath(), - "kind": root.GetKind(), - "label": root.GetLabel(), - }) - } - - _ = c.writeResponse(req.ID, map[string]any{ - "roots": rootPayload, - }) + handler(c, req) } -func (c *websocketConnection) handleFsListDirs(req websocketRequest) { - type payload struct { - Path string `json:"path"` - MaxResults *int `json:"max_results"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid fs.list_dirs payload") - return - } - - dir := strings.TrimSpace(body.Path) - if dir == "" { - _ = c.writeError(req.ID, "path is required") - return - } - - maxResults, err := websocketOptionalUint32(body.MaxResults, "max_results") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FsListDirs{ - FsListDirs: &gatewayv1.FsListDirsRequest{ - Path: dir, - MaxResults: maxResults, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFsListDirsResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } +func (c *websocketConnection) awaitAgentResponse( + requestID string, + envelope *gatewayv1.GatewayEnvelope, +) (*gatewayv1.AgentEnvelope, error) { + ctx, cancel := context.WithTimeout(context.Background(), c.cfg.RequestTimeout) + defer cancel() - entryPayload := make([]map[string]any, 0, len(resp.GetEntries())) - for _, entry := range resp.GetEntries() { - entryPayload = append(entryPayload, map[string]any{ - "path": entry.GetPath(), - "name": entry.GetName(), - }) - } + go func() { + select { + case <-c.done: + cancel() + case <-ctx.Done(): + } + }() - _ = c.writeResponse(req.ID, map[string]any{ - "path": strings.TrimSpace(resp.GetPath()), - "entries": entryPayload, - "truncated": resp.GetTruncated(), - }) + return awaitAgentUnaryResponse(ctx, c.sm, requestID, envelope) } -func (c *websocketConnection) handleFsCreateProjectFolder(req websocketRequest) { - type payload struct { - Parent string `json:"parent"` - Name string `json:"name"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid fs.create_project_folder payload") - return - } - - parent := strings.TrimSpace(body.Parent) - name := strings.TrimSpace(body.Name) - if parent == "" { - _ = c.writeError(req.ID, "parent is required") - return - } - if name == "" { - _ = c.writeError(req.ID, "name is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FsCreateProjectFolder{ - FsCreateProjectFolder: &gatewayv1.FsCreateProjectFolderRequest{ - Parent: parent, - Name: name, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFsCreateProjectFolderResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{ - "path": strings.TrimSpace(resp.GetPath()), +func (c *websocketConnection) writeResponse(requestID string, payload any) error { + return c.writeEnvelope(websocketEnvelope{ + ID: requestID, + Type: "response", + Payload: payload, }) } -func (c *websocketConnection) handleFsList(req websocketRequest) { - type payload struct { - Workdir string `json:"workdir"` - Path string `json:"path"` - Depth *int `json:"depth"` - Offset *int `json:"offset"` - MaxResults *int `json:"max_results"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid fs.list payload") - return - } - - workdir := strings.TrimSpace(body.Workdir) - if workdir == "" { - _ = c.writeError(req.ID, "workdir is required") - return - } - - depth, err := websocketOptionalUint32(body.Depth, "depth") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - offset, err := websocketOptionalUint32(body.Offset, "offset") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - maxResults, err := websocketOptionalUint32(body.MaxResults, "max_results") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FsList{ - FsList: &gatewayv1.FsListRequest{ - Workdir: workdir, - Path: strings.TrimSpace(body.Path), - Depth: depth, - Offset: offset, - MaxResults: maxResults, - }, - }, +func (c *websocketConnection) writeError(requestID string, message string) error { + return c.writeEnvelope(websocketEnvelope{ + ID: requestID, + Type: "error", + Error: message, }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFsListResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, websocketFsListResponsePayload(resp)) } -func (c *websocketConnection) handleFsWriteText(req websocketRequest) { - type payload struct { - Workdir string `json:"workdir"` - Path string `json:"path"` - Content string `json:"content"` - Mode string `json:"mode"` - ExpectedMtimeMs *uint64 `json:"expected_mtime_ms"` - ExpectedContentHash *string `json:"expected_content_hash"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid fs.write_text payload") - return - } - - workdir := strings.TrimSpace(body.Workdir) - path := strings.TrimSpace(body.Path) - if workdir == "" { - _ = c.writeError(req.ID, "workdir is required") - return - } - if path == "" { - _ = c.writeError(req.ID, "path is required") - return - } - mode := strings.TrimSpace(body.Mode) - if mode == "" { - mode = "rewrite" - } - expectedHash := "" - hasExpectedHash := false - if body.ExpectedContentHash != nil { - expectedHash = strings.TrimSpace(*body.ExpectedContentHash) - hasExpectedHash = true - } - expectedMtime := uint64(0) - hasExpectedMtime := false - if body.ExpectedMtimeMs != nil { - expectedMtime = *body.ExpectedMtimeMs - hasExpectedMtime = true - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FsWriteText{ - FsWriteText: &gatewayv1.FsWriteTextRequest{ - Workdir: workdir, - Path: path, - Content: body.Content, - Mode: mode, - ExpectedMtimeMs: expectedMtime, - ExpectedContentHash: expectedHash, - HasExpectedMtimeMs: hasExpectedMtime, - HasExpectedContentHash: hasExpectedHash, - }, - }, +func (c *websocketConnection) writeChatEvent(requestID string, payload any) error { + return c.writeEnvelope(websocketEnvelope{ + ID: requestID, + Type: "chat.event", + Payload: payload, }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFsWriteTextResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, websocketFsWriteTextResponsePayload(resp)) } -func (c *websocketConnection) handleFsCreateDir(req websocketRequest) { - type payload struct { - Workdir string `json:"workdir"` - Path string `json:"path"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid fs.create_dir payload") - return - } - - workdir := strings.TrimSpace(body.Workdir) - path := strings.TrimSpace(body.Path) - if workdir == "" { - _ = c.writeError(req.ID, "workdir is required") - return - } - if path == "" { - _ = c.writeError(req.ID, "path is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FsCreateDir{ - FsCreateDir: &gatewayv1.FsCreateDirRequest{ - Workdir: workdir, - Path: path, - }, - }, +func (c *websocketConnection) writeHistoryEvent(payload any) error { + return c.writeEnvelope(websocketEnvelope{ + Type: "history.event", + Payload: payload, }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFsCreateDirResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, websocketFsCreateDirResponsePayload(resp)) } -func (c *websocketConnection) handleFsRename(req websocketRequest) { - type payload struct { - Workdir string `json:"workdir"` - FromPath string `json:"from_path"` - ToPath string `json:"to_path"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid fs.rename payload") - return - } - - workdir := strings.TrimSpace(body.Workdir) - fromPath := strings.TrimSpace(body.FromPath) - toPath := strings.TrimSpace(body.ToPath) - if workdir == "" { - _ = c.writeError(req.ID, "workdir is required") - return - } - if fromPath == "" { - _ = c.writeError(req.ID, "from_path is required") - return - } - if toPath == "" { - _ = c.writeError(req.ID, "to_path is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FsRename{ - FsRename: &gatewayv1.FsRenameRequest{ - Workdir: workdir, - FromPath: fromPath, - ToPath: toPath, - }, - }, +func (c *websocketConnection) writeConversationEvent(payload any) error { + return c.writeEnvelope(websocketEnvelope{ + Type: "conversation.event", + Payload: payload, }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFsRenameResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, websocketFsRenameResponsePayload(resp)) } -func (c *websocketConnection) handleFsDelete(req websocketRequest) { - type payload struct { - Workdir string `json:"workdir"` - Path string `json:"path"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid fs.delete payload") - return - } - - workdir := strings.TrimSpace(body.Workdir) - path := strings.TrimSpace(body.Path) - if workdir == "" { - _ = c.writeError(req.ID, "workdir is required") - return - } - if path == "" { - _ = c.writeError(req.ID, "path is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FsDelete{ - FsDelete: &gatewayv1.FsDeleteRequest{ - Workdir: workdir, - Path: path, - }, - }, +func (c *websocketConnection) writeSettingsEvent(payload any) error { + return c.writeEnvelope(websocketEnvelope{ + Type: "settings.event", + Payload: payload, }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFsDeleteResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, websocketFsDeleteResponsePayload(resp)) } -func (c *websocketConnection) handleHistoryList(req websocketRequest) { - type payload struct { - Page int `json:"page"` - PageSize int `json:"page_size"` - Cwd string `json:"cwd"` - CwdEmpty bool `json:"cwd_empty"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.list payload") - return - } - page := body.Page - if page <= 0 { - page = defaultHistoryListPage - } - pageSize := body.PageSize - if pageSize <= 0 { - pageSize = defaultHistoryListPageSize - } else if pageSize > maxHistoryListLimit { - pageSize = maxHistoryListLimit - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryList{ - HistoryList: &gatewayv1.HistoryListRequest{ - Page: int32(page), - PageSize: int32(pageSize), - Cwd: strings.TrimSpace(body.Cwd), - CwdEmpty: body.CwdEmpty, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetHistoryListResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - conversations := make([]map[string]any, 0, len(resp.GetConversations())) - for _, conversation := range resp.GetConversations() { - conversations = append(conversations, websocketConversationSummaryPayload(conversation)) - } - - _ = c.writeResponse(req.ID, map[string]any{ - "conversations": conversations, - "total_count": resp.GetTotalCount(), - "running_conversation_ids": c.sm.ActiveChatRunConversationIDs(), - "running_conversations": websocketActiveChatRunSummariesPayload(c.sm.ActiveChatRunSummaries()), - }) -} - -func (c *websocketConnection) handleHistoryWorkdirs(req websocketRequest) { - var body struct{} - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.workdirs payload") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryWorkdirs{ - HistoryWorkdirs: &gatewayv1.HistoryWorkdirsRequest{}, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetHistoryWorkdirsResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - workdirs := make([]map[string]any, 0, len(resp.GetWorkdirs())) - for _, workdir := range resp.GetWorkdirs() { - workdirs = append(workdirs, map[string]any{ - "path": workdir.GetPath(), - "conversation_count": workdir.GetConversationCount(), - "updated_at": workdir.GetUpdatedAt(), - }) - } - - _ = c.writeResponse(req.ID, map[string]any{ - "workdirs": workdirs, - }) -} - -func (c *websocketConnection) handleHistorySharedList(req websocketRequest) { - type payload struct { - Page int `json:"page"` - PageSize int `json:"page_size"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.shared_list payload") - return - } - page := body.Page - if page <= 0 { - page = defaultHistoryListPage - } - pageSize := body.PageSize - if pageSize <= 0 { - pageSize = defaultHistoryListPageSize - } else if pageSize > maxHistoryListLimit { - pageSize = maxHistoryListLimit - } - - argsJSON, err := json.Marshal(map[string]any{ - "page": page, - "page_size": pageSize, - }) - if err != nil { - _ = c.writeError(req.ID, "invalid history.shared_list payload") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_MemoryManage{ - MemoryManage: &gatewayv1.MemoryManageRequest{ - Command: "history_shared_list", - ArgsJson: string(argsJSON), - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetMemoryManageResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - var result struct { - Conversations []map[string]any `json:"conversations"` - TotalCount int `json:"total_count"` - } - if err := json.Unmarshal([]byte(resp.GetResultJson()), &result); err != nil { - _ = c.writeError(req.ID, "invalid history.shared_list response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{ - "conversations": result.Conversations, - "total_count": result.TotalCount, - }) -} - -func (c *websocketConnection) handleHistoryGet(req websocketRequest) { - type payload struct { - ConversationID string `json:"conversation_id"` - MaxMessages int32 `json:"max_messages"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.get payload") - return - } - if strings.TrimSpace(body.ConversationID) == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryGet{ - HistoryGet: &gatewayv1.HistoryGetRequest{ - ConversationId: body.ConversationID, - MaxMessages: body.MaxMessages, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetHistoryGetResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{ - "conversation_id": resp.GetConversationId(), - "messages_json": resp.GetMessagesJson(), - "total_message_count": resp.GetTotalMessageCount(), - "returned_message_count": resp.GetReturnedMessageCount(), - "has_more": resp.GetHasMore(), - "conversation": websocketConversationSummaryPayload(resp.GetConversation()), - }) -} - -func (c *websocketConnection) handleHistoryRename(req websocketRequest) { - type payload struct { - ConversationID string `json:"conversation_id"` - Title string `json:"title"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.rename payload") - return - } - if strings.TrimSpace(body.ConversationID) == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - if strings.TrimSpace(body.Title) == "" { - _ = c.writeError(req.ID, "title is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryRename{ - HistoryRename: &gatewayv1.HistoryRenameRequest{ - ConversationId: body.ConversationID, - Title: body.Title, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetHistoryRenameResp() - if resp == nil || resp.GetConversation() == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - conversation := resp.GetConversation() - _ = c.writeResponse(req.ID, websocketConversationSummaryPayload(conversation)) -} - -func (c *websocketConnection) handleHistoryPin(req websocketRequest) { - type payload struct { - ConversationID string `json:"conversation_id"` - IsPinned bool `json:"is_pinned"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.pin payload") - return - } - if strings.TrimSpace(body.ConversationID) == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryPin{ - HistoryPin: &gatewayv1.HistoryPinRequest{ - ConversationId: body.ConversationID, - IsPinned: body.IsPinned, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetHistoryPinResp() - if resp == nil || resp.GetConversation() == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, websocketConversationSummaryPayload(resp.GetConversation())) -} - -func (c *websocketConnection) handleHistoryShareGet(req websocketRequest) { - type payload struct { - ConversationID string `json:"conversation_id"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.share.get payload") - return - } - if strings.TrimSpace(body.ConversationID) == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryShareGet{ - HistoryShareGet: &gatewayv1.HistoryShareGetRequest{ - ConversationId: body.ConversationID, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetHistoryShareGetResp() - if resp == nil || resp.GetShare() == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, websocketHistoryShareStatusPayload(resp.GetShare())) -} - -func (c *websocketConnection) handleHistoryShareSet(req websocketRequest) { - type payload struct { - ConversationID string `json:"conversation_id"` - Enabled bool `json:"enabled"` - RedactToolContent *bool `json:"redact_tool_content,omitempty"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.share.set payload") - return - } - if strings.TrimSpace(body.ConversationID) == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryShareSet{ - HistoryShareSet: &gatewayv1.HistoryShareSetRequest{ - ConversationId: body.ConversationID, - Enabled: body.Enabled, - RedactToolContent: body.RedactToolContent, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetHistoryShareSetResp() - if resp == nil || resp.GetShare() == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, websocketHistoryShareStatusPayload(resp.GetShare())) -} - -func (c *websocketConnection) handleHistoryDelete(req websocketRequest) { - type payload struct { - ConversationID string `json:"conversation_id"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.delete payload") - return - } - if strings.TrimSpace(body.ConversationID) == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryDelete{ - HistoryDelete: &gatewayv1.HistoryDeleteRequest{ - ConversationId: body.ConversationID, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - if response.GetHistoryDeleteResp() == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{"ok": true}) -} - -func (c *websocketConnection) handleHistoryTruncate(req websocketRequest) { - type payload struct { - ConversationID string `json:"conversation_id"` - SegmentIndex int `json:"segment_index"` - MessageIndex int `json:"message_index"` - OmitMessagesJSON bool `json:"omit_messages_json"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid history.truncate payload") - return - } - if strings.TrimSpace(body.ConversationID) == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - if body.SegmentIndex < 0 { - _ = c.writeError(req.ID, "segment_index must be >= 0") - return - } - if body.MessageIndex < 0 { - _ = c.writeError(req.ID, "message_index must be >= 0") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_HistoryTruncate{ - HistoryTruncate: &gatewayv1.HistoryTruncateRequest{ - ConversationId: body.ConversationID, - SegmentIndex: int32(body.SegmentIndex), - MessageIndex: int32(body.MessageIndex), - OmitMessagesJson: body.OmitMessagesJSON, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetHistoryTruncateResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - payloadMap := map[string]any{ - "conversation_id": resp.GetConversationId(), - "messages_json": resp.GetMessagesJson(), - } - if conversation := resp.GetConversation(); conversation != nil { - payloadMap["conversation"] = websocketConversationSummaryPayload(conversation) - } - - _ = c.writeResponse(req.ID, payloadMap) -} - -func (c *websocketConnection) handleProviderList(req websocketRequest) { - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_ProviderList{ - ProviderList: &gatewayv1.ProviderListRequest{}, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetProviderListResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - var payload any - raw := strings.TrimSpace(resp.GetProvidersJson()) - if raw == "" { - payload = []any{} - } else if err := json.Unmarshal([]byte(raw), &payload); err != nil { - _ = c.writeError(req.ID, "provider list response is not valid JSON") - return - } - - _ = c.writeResponse(req.ID, payload) -} - -func (c *websocketConnection) handleChatStart(req websocketRequest) { - var body handler.ChatRequestBody - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid chat.start payload") - return - } - body.Message = strings.TrimSpace(body.Message) - body.ConversationID = strings.TrimSpace(body.ConversationID) - body.ClientRequestID = strings.TrimSpace(body.ClientRequestID) - body.ExecutionMode = handler.NormalizeExecutionMode(body.ExecutionMode) - body.Workdir = handler.NormalizeWorkdir(body.Workdir) - body.SelectedSystemTools = handler.NormalizeSelectedSystemTools(body.SelectedSystemTools) - body.UploadedFiles = handler.NormalizeChatUploadedFiles(body.UploadedFiles) - body.RuntimeControls = handler.NormalizeChatRuntimeControls(body.RuntimeControls) - selectedModel, err := handler.NormalizeChatSelectedModel(body.SelectedModel) - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - body.SelectedModel = selectedModel - if body.Message == "" && len(body.UploadedFiles) == 0 { - _ = c.writeError(req.ID, "message is required") - return - } - if !c.sm.IsOnline() { - _ = c.writeError(req.ID, "agent offline") - return - } - - snapshot, created, err := c.sm.StartChatRunWithClientRequest( - req.ID, - body.ConversationID, - body.ClientRequestID, - body.Workdir, - ) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - sourceRequestID := snapshot.RequestID - if sourceRequestID == "" { - sourceRequestID = req.ID - } - eventCh, eventDone, cleanup, snapshot, err := c.sm.SubscribeChatRun( - sourceRequestID, - snapshot.ConversationID, - 0, - ) - if err != nil { - if created { - c.sm.RemoveChatRun(sourceRequestID) - } - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - defer cleanup() - - // Register before sending so the broadcast forwarder can skip the copy that - // this same connection already receives through the recoverable chat stream. - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - responseID := req.ID - c.registerActiveChat(responseID, sourceRequestID, snapshot.ConversationID, cancel) - defer c.releaseActiveChat(responseID) - - if created { - if err := c.sm.SendToAgent(&gatewayv1.GatewayEnvelope{ - RequestId: sourceRequestID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_ChatRequest{ - ChatRequest: &gatewayv1.ChatRequest{ - ConversationId: body.ConversationID, - ClientRequestId: body.ClientRequestID, - Message: body.Message, - SelectedModel: handler.ToProtoChatSelectedModel(body.SelectedModel), - RuntimeControls: handler.ToProtoChatRuntimeControls(body.RuntimeControls), - ExecutionMode: body.ExecutionMode, - Workdir: body.Workdir, - SelectedSystemTools: body.SelectedSystemTools, - UploadedFiles: handler.ToProtoChatUploadedFiles(body.UploadedFiles), - }, - }, - }); err != nil { - c.sm.RemoveChatRun(sourceRequestID) - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - } - - // Do not enforce a hard timeout for streaming chat requests. The GUI path can run - // multiple compaction rounds stably; WebUI should behave the same and only stop - // when the user cancels, the connection closes, or the agent returns done/error. - for { - select { - case <-c.done: - return - case <-ctx.Done(): - _ = c.writeError(responseID, websocketErrorMessage(ctx.Err())) - return - case <-eventDone: - return - case event, ok := <-eventCh: - if !ok { - return - } - chatEvent := event.Event - if chatEvent == nil { - continue - } - if chatEvent.GetConversationId() != "" { - body.ConversationID = strings.TrimSpace(chatEvent.GetConversationId()) - c.updateActiveChatConversationID(responseID, body.ConversationID) - } - if err := c.writeChatEvent(responseID, websocketChatEventPayload(chatEvent, event.Seq, event.Workdir)); err != nil { - c.close() - return - } - if chatEvent.GetType() == gatewayv1.ChatEvent_DONE || chatEvent.GetType() == gatewayv1.ChatEvent_ERROR { - return - } - } - } -} - -func (c *websocketConnection) handleChatResume(req websocketRequest) { - var body websocketChatResumePayload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid chat.resume payload") - return - } - body.RequestID = strings.TrimSpace(body.RequestID) - body.ConversationID = strings.TrimSpace(body.ConversationID) - if body.RequestID == "" && body.ConversationID == "" { - _ = c.writeError(req.ID, "request_id or conversation_id is required") - return - } - if body.AfterSeq < 0 { - body.AfterSeq = 0 - } - - eventCh, eventDone, cleanup, snapshot, err := c.sm.SubscribeChatRun( - body.RequestID, - body.ConversationID, - body.AfterSeq, - ) - if err != nil { - responseID := body.RequestID - if responseID == "" { - responseID = req.ID - } - _ = c.writeError(responseID, websocketErrorMessage(err)) - return - } - defer cleanup() - - responseID := snapshot.RequestID - if responseID == "" { - responseID = body.RequestID - } - if responseID == "" { - responseID = req.ID - } - - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - c.registerActiveChat(responseID, snapshot.RequestID, snapshot.ConversationID, cancel) - defer c.releaseActiveChat(responseID) - - if snapshot.Done && snapshot.LatestSeq <= body.AfterSeq { - payload := map[string]any{ - "type": "done", - "seq": snapshot.LatestSeq, - } - if snapshot.ConversationID != "" { - payload["conversation_id"] = snapshot.ConversationID - } - if err := c.writeChatEvent(responseID, payload); err != nil { - c.close() - } - return - } - - for { - select { - case <-c.done: - return - case <-ctx.Done(): - _ = c.writeError(responseID, websocketErrorMessage(ctx.Err())) - return - case <-eventDone: - return - case event, ok := <-eventCh: - if !ok { - return - } - chatEvent := event.Event - if chatEvent == nil { - continue - } - if chatEvent.GetConversationId() != "" { - c.updateActiveChatConversationID(responseID, strings.TrimSpace(chatEvent.GetConversationId())) - } - if err := c.writeChatEvent(responseID, websocketChatEventPayload(chatEvent, event.Seq, event.Workdir)); err != nil { - c.close() - return - } - if chatEvent.GetType() == gatewayv1.ChatEvent_DONE || chatEvent.GetType() == gatewayv1.ChatEvent_ERROR { - return - } - } - } -} - -func (c *websocketConnection) handleChatAttach(req websocketRequest) { - var body websocketChatAttachPayload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid chat.attach payload") - return - } - body.ConversationID = strings.TrimSpace(body.ConversationID) - if body.ConversationID == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - if body.AfterSeq < 0 { - body.AfterSeq = 0 - } - - eventCh, eventDone, cleanup, snapshot, err := c.sm.SubscribeChatRun( - "", - body.ConversationID, - body.AfterSeq, - ) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - defer cleanup() - - ctx, cancel := context.WithCancel(context.Background()) - c.registerActiveChatAttachment(req.ID, cancel) - defer c.releaseActiveChatAttachment(req.ID) - - if snapshot.Done && snapshot.LatestSeq <= body.AfterSeq { - payload := map[string]any{ - "type": "done", - "seq": snapshot.LatestSeq, - } - if snapshot.ConversationID != "" { - payload["conversation_id"] = snapshot.ConversationID - } - if err := c.writeChatEvent(req.ID, payload); err != nil { - c.close() - } - return - } - - for { - select { - case <-c.done: - return - case <-ctx.Done(): - return - case <-eventDone: - return - case event, ok := <-eventCh: - if !ok { - return - } - chatEvent := event.Event - if chatEvent == nil { - continue - } - if err := c.writeChatEvent(req.ID, websocketChatEventPayload(chatEvent, event.Seq, event.Workdir)); err != nil { - c.close() - return - } - if chatEvent.GetType() == gatewayv1.ChatEvent_DONE || chatEvent.GetType() == gatewayv1.ChatEvent_ERROR { - return - } - } - } -} - -func (c *websocketConnection) handleChatDetach(req websocketRequest) { - var body websocketChatDetachPayload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid chat.detach payload") - return - } - targetRequestID := strings.TrimSpace(body.RequestID) - if targetRequestID == "" { - targetRequestID = req.ID - } - if targetRequestID == "" { - _ = c.writeError(req.ID, "request_id is required") - return - } - c.cancelActiveChatAttachment(targetRequestID) - _ = c.writeResponse(req.ID, map[string]any{"ok": true}) -} - -func (c *websocketConnection) handleChatCancel(req websocketRequest) { - var body handler.CancelChatRequestBody - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid chat.cancel payload") - return - } - body.ConversationID = strings.TrimSpace(body.ConversationID) - if body.ConversationID == "" { - _ = c.writeError(req.ID, "conversation_id is required") - return - } - if !c.sm.IsOnline() { - _ = c.writeError(req.ID, "agent offline") - return - } - - if err := c.sm.SendToAgent(&gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_CancelChat{ - CancelChat: &gatewayv1.CancelChatRequest{ - ConversationId: body.ConversationID, - }, - }, - }); err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - - c.cancelActiveChatsByConversation(body.ConversationID) - c.sm.RemoveChatRunByConversation(body.ConversationID) - _ = c.writeResponse(req.ID, map[string]any{"ok": true}) -} - -func (c *websocketConnection) handleUploadedImagePreview(req websocketRequest) { - var body handler.UploadedImagePreviewRequestBody - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid files.preview payload") - return - } - body.Workdir = strings.TrimSpace(body.Workdir) - body.AbsolutePath = strings.TrimSpace(body.AbsolutePath) - if body.Workdir == "" { - _ = c.writeError(req.ID, "workdir is required") - return - } - if body.AbsolutePath == "" { - _ = c.writeError(req.ID, "absolute_path is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_UploadedImagePreview{ - UploadedImagePreview: &gatewayv1.UploadedImagePreviewRequest{ - Workdir: body.Workdir, - AbsolutePath: body.AbsolutePath, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetUploadedImagePreviewResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{ - "mimeType": resp.GetMimeType(), - "data": resp.GetData(), - }) -} - -func (c *websocketConnection) handleMemoryManage(req websocketRequest) { - type payload struct { - Command string `json:"command"` - Args json.RawMessage `json:"args"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid memory.manage payload") - return - } - - command := strings.TrimSpace(body.Command) - if command == "" { - _ = c.writeError(req.ID, "command is required") - return - } - if !strings.HasPrefix(command, "memory_") { - _ = c.writeError(req.ID, "unsupported memory command") - return - } - - argsJSON := strings.TrimSpace(string(body.Args)) - if argsJSON == "" { - argsJSON = "{}" - } - if !json.Valid([]byte(argsJSON)) { - _ = c.writeError(req.ID, "memory args must be valid JSON") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_MemoryManage{ - MemoryManage: &gatewayv1.MemoryManageRequest{ - Command: command, - ArgsJson: argsJSON, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetMemoryManageResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - payloadValue, err := websocketMemoryResultPayload(resp.GetResultJson()) - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - _ = c.writeResponse(req.ID, payloadValue) -} - -func terminalActionFromRequestType(requestType string) string { - return strings.TrimPrefix(strings.TrimSpace(requestType), "terminal.") -} - -func (c *websocketConnection) handleTerminalRequest(req websocketRequest) { - action := terminalActionFromRequestType(req.Type) - if !c.sm.WebTerminalEnabled() { - _ = c.writeError(req.ID, "web terminal is disabled in desktop Remote settings") - return - } - - var body websocketTerminalRequestPayload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid "+req.Type+" payload") - return - } - - cols, err := websocketOptionalUint32(body.Cols, "cols") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - rows, err := websocketOptionalUint32(body.Rows, "rows") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - maxBytes, err := websocketOptionalUint32(body.MaxBytes, "max_bytes") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - projectPathKey := strings.TrimSpace(body.ProjectPathKey) - if action == "attach" || action == "snapshot" { - c.rememberTerminalSession(body.SessionID, projectPathKey) - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_TerminalRequest{ - TerminalRequest: &gatewayv1.TerminalRequest{ - Action: action, - SessionId: strings.TrimSpace(body.SessionID), - ProjectPathKey: projectPathKey, - Cwd: strings.TrimSpace(body.Cwd), - Shell: strings.TrimSpace(body.Shell), - Title: strings.TrimSpace(body.Title), - Data: body.Data, - Cols: cols, - Rows: rows, - MaxBytes: maxBytes, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetTerminalResponse() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - c.sm.ApplyTerminalResponseSnapshot(action, projectPathKey, resp) - c.rememberTerminalInterest(action, body, resp) - - _ = c.writeResponse(req.ID, websocketTerminalResponsePayload(resp)) -} - -func (c *websocketConnection) rememberTerminalInterest(action string, body websocketTerminalRequestPayload, resp *gatewayv1.TerminalResponse) { - projectPathKey := strings.TrimSpace(body.ProjectPathKey) - sessionID := strings.TrimSpace(body.SessionID) - if respSession := resp.GetSession(); respSession != nil { - if projectPathKey == "" { - projectPathKey = strings.TrimSpace(respSession.GetProjectPathKey()) - } - if sessionID == "" { - sessionID = strings.TrimSpace(respSession.GetId()) - } - } - - switch action { - case "list", "create", "close_project": - c.rememberTerminalProject(projectPathKey) - case "attach", "snapshot": - c.rememberTerminalSession(sessionID, projectPathKey) - } -} - -func (c *websocketConnection) handleTerminalDetach(req websocketRequest) { - if !c.sm.WebTerminalEnabled() { - _ = c.writeError(req.ID, "web terminal is disabled in desktop Remote settings") - return - } - var body websocketTerminalRequestPayload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid terminal.detach payload") - return - } - c.forgetTerminalInterest(body.SessionID, body.ProjectPathKey) - _ = c.writeResponse(req.ID, map[string]any{"action": "detach"}) -} - -func gitActionFromRequestType(requestType string) string { - return strings.TrimPrefix(strings.TrimSpace(requestType), "git.") -} - -func gitActionIsWrite(action string) bool { - switch action { - case "init", "switch_branch", "create_branch", "stage", "stage_all", "unstage", "unstage_all", "discard", "discard_all", "add_to_gitignore", "commit", "fetch", "pull", "set_remote", "push": - return true - default: - return false - } -} - -func (c *websocketConnection) handleGitRequest(req websocketRequest) { - action := gitActionFromRequestType(req.Type) - if gitActionIsWrite(action) && !c.sm.WebGitEnabled() { - _ = c.writeError(req.ID, "web git is disabled in desktop Remote settings") - return - } - - var body websocketGitRequestPayload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid "+req.Type+" payload") - return - } - argsJSON := strings.TrimSpace(string(body.Args)) - if argsJSON == "" { - argsJSON = "{}" - } - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_GitRequest{ - GitRequest: &gatewayv1.GitRequest{ - Action: action, - Workdir: strings.TrimSpace(body.Workdir), - ArgsJson: argsJSON, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - resp := response.GetGitResponse() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - payload, err := websocketGitResultPayload(resp.GetResultJson()) - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - _ = c.writeResponse(req.ID, payload) -} - -func (c *websocketConnection) handleCronManage(req websocketRequest) { - var body handler.CronManageRequestBody - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid cron.manage payload") - return - } - if !c.sm.IsOnline() { - _ = c.writeError(req.ID, "agent offline") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_CronManage{ - CronManage: &gatewayv1.CronManageRequest{ - Action: body.Action, - TaskId: body.TaskID, - TaskJson: body.TaskJSON, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetCronManageResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{ - "action": resp.GetAction(), - "result_json": resp.GetResultJson(), - }) -} - -func (c *websocketConnection) handleProviderModels(req websocketRequest) { - var body handler.ProviderModelsRequestBody - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid provider.models payload") - return - } - - ctx, cancel := context.WithTimeout(context.Background(), c.cfg.RequestTimeout) - defer cancel() - - result, err := handler.FetchProviderModels(ctx, body) - if err != nil { - var statusErr *handler.HTTPStatusError - if errors.As(err, &statusErr) { - _ = c.writeError(req.ID, statusErr.Message) - return - } - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - - var payload any - if err := json.Unmarshal(result.Body, &payload); err != nil { - _ = c.writeError(req.ID, "provider model response is not valid JSON") - return - } - - _ = c.writeResponse(req.ID, payload) -} - -func (c *websocketConnection) handleSettingsGet(req websocketRequest) { - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_SettingsGet{ - SettingsGet: &gatewayv1.SettingsGetRequest{}, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - settingsResp := response.GetSettingsGetResp() - if settingsResp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - payload, err := websocketSettingsJSONPayload(settingsResp.GetSettingsJson()) - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - c.sm.ApplySettingsJSON(settingsResp.GetSettingsJson()) - - _ = c.writeResponse(req.ID, payload) -} - -func (c *websocketConnection) handleSettingsUpdate(req websocketRequest) { - payloadJSON, err := websocketRawPayloadJSON(req.Payload) - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_SettingsUpdate{ - SettingsUpdate: &gatewayv1.SettingsUpdateRequest{ - SettingsJson: payloadJSON, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - settingsResp := response.GetSettingsUpdateResp() - if settingsResp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - if settingsResp.GetAccepted() { - c.sm.ApplySettingsJSONPreservingRemote(payloadJSON) - } - - _ = c.writeResponse(req.ID, map[string]any{ - "accepted": settingsResp.GetAccepted(), - "message": strings.TrimSpace(settingsResp.GetMessage()), - }) -} - -func (c *websocketConnection) handleSkillFilesList(req websocketRequest) { - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_SkillFilesList{ - SkillFilesList: &gatewayv1.SkillFilesListRequest{}, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetSkillFilesListResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{ - "rootDir": resp.GetRootDir(), - "paths": resp.GetPaths(), - "truncated": resp.GetTruncated(), - }) -} - -func (c *websocketConnection) handleFileMentionList(req websocketRequest) { - type payload struct { - Workdir string `json:"workdir"` - MaxResults *int `json:"max_results"` - Query string `json:"query"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid mentions.list payload") - return - } - - workdir := strings.TrimSpace(body.Workdir) - if workdir == "" { - _ = c.writeError(req.ID, "workdir is required") - return - } - query := strings.TrimSpace(body.Query) - - maxResults, err := websocketOptionalUint32(body.MaxResults, "max_results") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_FileMentionList{ - FileMentionList: &gatewayv1.FileMentionListRequest{ - Workdir: workdir, - MaxResults: maxResults, - Query: query, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetFileMentionListResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - entries := make([]map[string]any, 0, len(resp.GetEntries())) - for _, entry := range resp.GetEntries() { - entries = append(entries, map[string]any{ - "path": entry.GetPath(), - "kind": entry.GetKind(), - }) - } - - _ = c.writeResponse(req.ID, map[string]any{ - "entries": entries, - "truncated": resp.GetTruncated(), - }) -} - -func (c *websocketConnection) handleSkillMetadataRead(req websocketRequest) { - type payload struct { - Path string `json:"path"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid skills.read-metadata payload") - return - } - - path := strings.TrimSpace(body.Path) - if path == "" { - _ = c.writeError(req.ID, "path is required") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_SkillMetadataRead{ - SkillMetadataRead: &gatewayv1.SkillMetadataReadRequest{ - Path: path, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetSkillMetadataReadResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{ - "name": nullableTrimmedString(resp.GetName()), - "description": nullableTrimmedString(resp.GetDescription()), - }) -} - -func (c *websocketConnection) handleSkillTextRead(req websocketRequest) { - type payload struct { - Path string `json:"path"` - Offset *int `json:"offset"` - Length *int `json:"length"` - } - - var body payload - if err := decodeWebSocketPayload(req.Payload, &body); err != nil { - _ = c.writeError(req.ID, "invalid skills.read-text payload") - return - } - - path := strings.TrimSpace(body.Path) - if path == "" { - _ = c.writeError(req.ID, "path is required") - return - } - - offset, err := websocketOptionalUint32(body.Offset, "offset") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - length, err := websocketOptionalUint32(body.Length, "length") - if err != nil { - _ = c.writeError(req.ID, err.Error()) - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_SkillTextRead{ - SkillTextRead: &gatewayv1.SkillTextReadRequest{ - Path: path, - Offset: offset, - Length: length, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetSkillTextReadResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - _ = c.writeResponse(req.ID, map[string]any{ - "content": resp.GetContent(), - "truncated": resp.GetTruncated(), - }) -} - -func (c *websocketConnection) handleSkillManage(req websocketRequest) { - payloadJSON := strings.TrimSpace(string(req.Payload)) - if payloadJSON == "" || payloadJSON == "null" { - payloadJSON = "{}" - } - if !json.Valid([]byte(payloadJSON)) { - _ = c.writeError(req.ID, "invalid skills.manage payload") - return - } - - response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ - RequestId: req.ID, - Timestamp: time.Now().Unix(), - Payload: &gatewayv1.GatewayEnvelope_SkillManage{ - SkillManage: &gatewayv1.SkillManageRequest{ - PayloadJson: payloadJSON, - }, - }, - }) - if err != nil { - _ = c.writeError(req.ID, websocketErrorMessage(err)) - return - } - if errResp := response.GetError(); errResp != nil { - _ = c.writeError(req.ID, errResp.GetMessage()) - return - } - - resp := response.GetSkillManageResp() - if resp == nil { - _ = c.writeError(req.ID, "unexpected agent response") - return - } - - var payload any - raw := strings.TrimSpace(resp.GetResultJson()) - if raw == "" { - payload = map[string]any{} - } else if err := json.Unmarshal([]byte(raw), &payload); err != nil { - _ = c.writeError(req.ID, "skill manage response is not valid JSON") - return - } - - _ = c.writeResponse(req.ID, payload) -} - -func websocketFsListResponsePayload(resp *gatewayv1.FsListResponse) map[string]any { - entryPayload := make([]map[string]any, 0, len(resp.GetEntries())) - for _, entry := range resp.GetEntries() { - entryPayload = append(entryPayload, map[string]any{ - "path": entry.GetPath(), - "kind": entry.GetKind(), - }) - } - - var path any - if resp.GetHasPath() { - path = resp.GetPath() - } - - return map[string]any{ - "path": path, - "depth": resp.GetDepth(), - "offset": resp.GetOffset(), - "maxResults": resp.GetMaxResults(), - "total": resp.GetTotal(), - "hasMore": resp.GetHasMore(), - "entries": entryPayload, - } -} - -func websocketFsWriteTextResponsePayload(resp *gatewayv1.FsWriteTextResponse) map[string]any { - return map[string]any{ - "path": resp.GetPath(), - "mode": resp.GetMode(), - "existedBefore": resp.GetExistedBefore(), - "bytesWritten": resp.GetBytesWritten(), - "mtimeMs": resp.GetMtimeMs(), - "contentHash": resp.GetContentHash(), - "totalLines": resp.GetTotalLines(), - } -} - -func websocketFsCreateDirResponsePayload(resp *gatewayv1.FsCreateDirResponse) map[string]any { - return map[string]any{ - "path": resp.GetPath(), - "kind": resp.GetKind(), - } -} - -func websocketFsRenameResponsePayload(resp *gatewayv1.FsRenameResponse) map[string]any { - return map[string]any{ - "fromPath": resp.GetFromPath(), - "path": resp.GetPath(), - "kind": resp.GetKind(), - } -} - -func websocketFsDeleteResponsePayload(resp *gatewayv1.FsDeleteResponse) map[string]any { - return map[string]any{ - "path": resp.GetPath(), - "kind": resp.GetKind(), - } -} - -func nullableTrimmedString(value string) any { - trimmed := strings.TrimSpace(value) - if trimmed == "" { - return nil - } - return trimmed -} - -func websocketOptionalUint32(value *int, field string) (uint32, error) { - if value == nil { - return 0, nil - } - if *value < 0 { - return 0, errors.New(field + " must be >= 0") - } - return uint32(*value), nil -} - -func (c *websocketConnection) awaitAgentResponse( - requestID string, - envelope *gatewayv1.GatewayEnvelope, -) (*gatewayv1.AgentEnvelope, error) { - ctx, cancel := context.WithTimeout(context.Background(), c.cfg.RequestTimeout) - defer cancel() - - go func() { - select { - case <-c.done: - cancel() - case <-ctx.Done(): - } - }() - - return awaitAgentUnaryResponse(ctx, c.sm, requestID, envelope) -} - -func (c *websocketConnection) writeResponse(requestID string, payload any) error { - return c.writeEnvelope(websocketEnvelope{ - ID: requestID, - Type: "response", - Payload: payload, - }) -} - -func (c *websocketConnection) writeError(requestID string, message string) error { - return c.writeEnvelope(websocketEnvelope{ - ID: requestID, - Type: "error", - Error: message, - }) -} - -func (c *websocketConnection) writeChatEvent(requestID string, payload any) error { - return c.writeEnvelope(websocketEnvelope{ - ID: requestID, - Type: "chat.event", - Payload: payload, - }) -} - -func (c *websocketConnection) writeHistoryEvent(payload any) error { - return c.writeEnvelope(websocketEnvelope{ - Type: "history.event", - Payload: payload, - }) -} - -func (c *websocketConnection) writeConversationEvent(payload any) error { - return c.writeEnvelope(websocketEnvelope{ - Type: "conversation.event", - Payload: payload, - }) -} - -func (c *websocketConnection) writeSettingsEvent(payload any) error { - return c.writeEnvelope(websocketEnvelope{ - Type: "settings.event", - Payload: payload, - }) -} - -func (c *websocketConnection) writeTerminalEvent(payload any) error { - return c.writeEnvelope(websocketEnvelope{ - Type: "terminal.event", - Payload: payload, +func (c *websocketConnection) writeTerminalEvent(payload any) error { + return c.writeEnvelope(websocketEnvelope{ + Type: "terminal.event", + Payload: payload, }) } func (c *websocketConnection) writeEnvelope(envelope websocketEnvelope) error { - c.writeMu.Lock() - defer c.writeMu.Unlock() - if c.cfg.WebSocketWriteTimeout > 0 { - if err := c.conn.SetWriteDeadline(time.Now().Add(c.cfg.WebSocketWriteTimeout)); err != nil { - return err - } - defer func() { - _ = c.conn.SetWriteDeadline(time.Time{}) - }() - } - return websocket.JSON.Send(c.conn, envelope) -} - -func websocketConversationSummaryPayload(conversation *gatewayv1.ConversationSummary) map[string]any { - if conversation == nil { - return nil - } - - return map[string]any{ - "id": conversation.GetId(), - "title": conversation.GetTitle(), - "created_at": conversation.GetCreatedAt(), - "updated_at": conversation.GetUpdatedAt(), - "message_count": conversation.GetMessageCount(), - "provider_id": conversation.GetProviderId(), - "model": conversation.GetModel(), - "session_id": conversation.GetSessionId(), - "cwd": conversation.GetCwd(), - "is_pinned": conversation.GetIsPinned(), - "pinned_at": conversation.GetPinnedAt(), - "is_shared": conversation.GetIsShared(), - } -} - -func websocketActiveChatRunSummariesPayload(summaries []session.ActiveChatRunSummary) []map[string]any { - payload := make([]map[string]any, 0, len(summaries)) - for _, summary := range summaries { - conversationID := strings.TrimSpace(summary.ConversationID) - if conversationID == "" { - continue - } - payload = append(payload, map[string]any{ - "conversation_id": conversationID, - "cwd": strings.TrimSpace(summary.Workdir), - "updated_at": summary.UpdatedAt, - }) - } - return payload -} - -func websocketHistoryShareStatusPayload(share *gatewayv1.HistoryShareStatus) map[string]any { - if share == nil { - return nil - } - - return map[string]any{ - "conversation_id": share.GetConversationId(), - "enabled": share.GetEnabled(), - "token": share.GetToken(), - "created_at": share.GetCreatedAt(), - "updated_at": share.GetUpdatedAt(), - "redact_tool_content": share.GetRedactToolContent(), - } -} - -func websocketHistorySyncPayload(event *gatewayv1.HistorySyncEvent) map[string]any { - payload := map[string]any{ - "kind": strings.TrimSpace(event.GetKind()), - "conversation_id": strings.TrimSpace(event.GetConversationId()), - } - - if conversation := event.GetConversation(); conversation != nil { - payload["conversation"] = websocketConversationSummaryPayload(conversation) - } - - return payload -} - -func websocketSettingsSyncPayload(event *gatewayv1.SettingsSyncEvent) (map[string]any, error) { - return websocketSettingsJSONPayload(event.GetSettingsJson()) -} - -func websocketSettingsJSONPayload(raw string) (map[string]any, error) { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return map[string]any{}, nil - } - - var payload map[string]any - if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { - return nil, errors.New("gateway settings payload is not valid JSON") - } - if payload == nil { - return map[string]any{}, nil - } - return payload, nil -} - -func websocketTerminalSessionPayload(session *gatewayv1.TerminalSession) map[string]any { - if session == nil { - return nil - } - payload := map[string]any{ - "id": strings.TrimSpace(session.GetId()), - "project_path_key": strings.TrimSpace(session.GetProjectPathKey()), - "cwd": strings.TrimSpace(session.GetCwd()), - "shell": strings.TrimSpace(session.GetShell()), - "title": strings.TrimSpace(session.GetTitle()), - "pid": session.GetPid(), - "cols": session.GetCols(), - "rows": session.GetRows(), - "created_at": session.GetCreatedAt(), - "updated_at": session.GetUpdatedAt(), - "finished_at": session.GetFinishedAt(), - "exit_code": session.GetExitCode(), - "running": session.GetRunning(), - } - if session.GetPid() == 0 { - payload["pid"] = nil - } - if session.GetFinishedAt() == 0 { - payload["finished_at"] = nil - } - return payload -} - -func websocketTerminalShellOptionPayload(option *gatewayv1.TerminalShellOption) map[string]any { - if option == nil { - return nil - } - return map[string]any{ - "id": strings.TrimSpace(option.GetId()), - "label": strings.TrimSpace(option.GetLabel()), - "command": strings.TrimSpace(option.GetCommand()), - } -} - -func websocketTerminalResponsePayload(resp *gatewayv1.TerminalResponse) map[string]any { - sessions := make([]map[string]any, 0, len(resp.GetSessions())) - for _, session := range resp.GetSessions() { - if payload := websocketTerminalSessionPayload(session); payload != nil { - sessions = append(sessions, payload) - } - } - shellOptions := make([]map[string]any, 0, len(resp.GetShellOptions())) - for _, option := range resp.GetShellOptions() { - if payload := websocketTerminalShellOptionPayload(option); payload != nil { - shellOptions = append(shellOptions, payload) - } - } - payload := map[string]any{ - "action": strings.TrimSpace(resp.GetAction()), - "sessions": sessions, - "output": resp.GetOutput(), - "truncated": resp.GetTruncated(), - "shell_options": shellOptions, - "default_shell": resp.GetDefaultShell(), - } - if resp.GetOutputStartOffset() != 0 || resp.GetOutputEndOffset() != 0 || resp.GetOutput() != "" { - payload["output_start_offset"] = resp.GetOutputStartOffset() - payload["output_end_offset"] = resp.GetOutputEndOffset() - } - if session := websocketTerminalSessionPayload(resp.GetSession()); session != nil { - payload["session"] = session - } - return payload -} - -func websocketTerminalEventPayload(event *gatewayv1.TerminalEvent) map[string]any { - payload := map[string]any{ - "kind": strings.TrimSpace(event.GetKind()), - "session_id": strings.TrimSpace(event.GetSessionId()), - "project_path_key": strings.TrimSpace(event.GetProjectPathKey()), - "data": event.GetData(), - } - if event.GetOutputStartOffset() != 0 || event.GetOutputEndOffset() != 0 || event.GetData() != "" { - payload["output_start_offset"] = event.GetOutputStartOffset() - payload["output_end_offset"] = event.GetOutputEndOffset() - } - if session := websocketTerminalSessionPayload(event.GetSession()); session != nil { - payload["session"] = session - } - return payload -} - -func websocketMemoryResultPayload(raw string) (any, error) { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return map[string]any{}, nil - } - - var payload any - if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { - return nil, errors.New("gateway memory response is not valid JSON") - } - if payload == nil { - return map[string]any{}, nil - } - return payload, nil -} - -func websocketGitResultPayload(raw string) (any, error) { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return map[string]any{}, nil - } - var payload any - if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { - return nil, errors.New("gateway git response is not valid JSON") - } - if payload == nil { - return map[string]any{}, nil - } - return payload, nil -} - -func websocketRawPayloadJSON(raw json.RawMessage) (string, error) { - trimmed := strings.TrimSpace(string(raw)) - if trimmed == "" { - return "{}", nil - } - - var payload map[string]any - if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { - return "", errors.New("invalid settings.update payload") - } - if payload == nil { - return "{}", nil - } - - normalized, err := json.Marshal(payload) - if err != nil { - return "", errors.New("invalid settings.update payload") - } - return string(normalized), nil -} - -func (c *websocketConnection) registerActiveChat( - requestID string, - sourceRequestID string, - conversationID string, - cancel context.CancelFunc, -) { - requestID = strings.TrimSpace(requestID) - sourceRequestID = strings.TrimSpace(sourceRequestID) - c.activeChatsMu.Lock() - defer c.activeChatsMu.Unlock() - c.activeChats[requestID] = &websocketChatState{ - cancel: cancel, - conversationID: strings.TrimSpace(conversationID), - sourceRequestID: sourceRequestID, - } - delete(c.recentChats, requestID) - delete(c.recentChats, sourceRequestID) -} - -func (c *websocketConnection) registerActiveChatAttachment(requestID string, cancel context.CancelFunc) { - requestID = strings.TrimSpace(requestID) - if requestID == "" { - return - } - c.activeChatAttachmentsMu.Lock() - defer c.activeChatAttachmentsMu.Unlock() - if existing := c.activeChatAttachments[requestID]; existing != nil { - existing() - } - c.activeChatAttachments[requestID] = cancel -} - -func (c *websocketConnection) releaseActiveChatAttachment(requestID string) { - requestID = strings.TrimSpace(requestID) - if requestID == "" { - return - } - c.activeChatAttachmentsMu.Lock() - delete(c.activeChatAttachments, requestID) - c.activeChatAttachmentsMu.Unlock() -} - -func (c *websocketConnection) releaseAllActiveChatAttachments() []context.CancelFunc { - c.activeChatAttachmentsMu.Lock() - defer c.activeChatAttachmentsMu.Unlock() - - cancels := make([]context.CancelFunc, 0, len(c.activeChatAttachments)) - for requestID, cancel := range c.activeChatAttachments { - delete(c.activeChatAttachments, requestID) - cancels = append(cancels, cancel) - } - return cancels -} - -func (c *websocketConnection) cancelActiveChatAttachment(requestID string) { - requestID = strings.TrimSpace(requestID) - if requestID == "" { - return - } - c.activeChatAttachmentsMu.Lock() - cancel := c.activeChatAttachments[requestID] - delete(c.activeChatAttachments, requestID) - c.activeChatAttachmentsMu.Unlock() - if cancel != nil { - cancel() - } -} - -func (c *websocketConnection) hasActiveChatRequest(requestID string) bool { - requestID = strings.TrimSpace(requestID) - if requestID == "" { - return false - } - c.activeChatsMu.Lock() - defer c.activeChatsMu.Unlock() - if _, ok := c.activeChats[requestID]; ok { - return true - } - for _, chat := range c.activeChats { - if chat.sourceRequestID == requestID { - return true - } - } - now := time.Now() - for recentRequestID, expiresAt := range c.recentChats { - if now.After(expiresAt) { - delete(c.recentChats, recentRequestID) - } - } - if expiresAt, ok := c.recentChats[requestID]; ok && now.Before(expiresAt) { - return true - } - return false -} - -func (c *websocketConnection) updateActiveChatConversationID(requestID string, conversationID string) { - c.activeChatsMu.Lock() - defer c.activeChatsMu.Unlock() - if chat, ok := c.activeChats[requestID]; ok { - chat.conversationID = strings.TrimSpace(conversationID) - } -} - -func (c *websocketConnection) releaseActiveChat(requestID string) *websocketChatState { - c.activeChatsMu.Lock() - defer c.activeChatsMu.Unlock() - chat := c.activeChats[requestID] - delete(c.activeChats, requestID) - expiresAt := time.Now().Add(recentActiveChatRetention) - if strings.TrimSpace(requestID) != "" { - c.recentChats[strings.TrimSpace(requestID)] = expiresAt - } - if chat != nil && chat.sourceRequestID != "" { - c.recentChats[chat.sourceRequestID] = expiresAt - } - return chat -} - -func (c *websocketConnection) releaseAllActiveChats() []*websocketChatState { - c.activeChatsMu.Lock() - defer c.activeChatsMu.Unlock() - - chats := make([]*websocketChatState, 0, len(c.activeChats)) - expiresAt := time.Now().Add(recentActiveChatRetention) - for requestID, chat := range c.activeChats { - delete(c.activeChats, requestID) - if strings.TrimSpace(requestID) != "" { - c.recentChats[strings.TrimSpace(requestID)] = expiresAt - } - if chat != nil && chat.sourceRequestID != "" { - c.recentChats[chat.sourceRequestID] = expiresAt - } - chats = append(chats, chat) - } - return chats -} - -func (c *websocketConnection) cancelActiveChatsByConversation(conversationID string) { - conversationID = strings.TrimSpace(conversationID) - if conversationID == "" { - return - } - - c.activeChatsMu.Lock() - chats := make([]*websocketChatState, 0, len(c.activeChats)) - expiresAt := time.Now().Add(recentActiveChatRetention) - for requestID, chat := range c.activeChats { - if chat.conversationID == conversationID { - delete(c.activeChats, requestID) - if strings.TrimSpace(requestID) != "" { - c.recentChats[strings.TrimSpace(requestID)] = expiresAt - } - if chat.sourceRequestID != "" { - c.recentChats[chat.sourceRequestID] = expiresAt - } - chats = append(chats, chat) - } - } - c.activeChatsMu.Unlock() - - for _, chat := range chats { - chat.cancel() - } -} - -func decodeWebSocketPayload(raw json.RawMessage, target any) error { - if len(raw) == 0 { - return json.Unmarshal([]byte("{}"), target) - } - decoder := json.NewDecoder(strings.NewReader(string(raw))) - decoder.DisallowUnknownFields() - return decoder.Decode(target) -} - -func waitForAgentEnvelope( - ctx context.Context, - ch <-chan *gatewayv1.AgentEnvelope, - done <-chan struct{}, -) (*gatewayv1.AgentEnvelope, error) { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-done: - return nil, session.ErrAgentOffline - case env, ok := <-ch: - if !ok { - return nil, session.ErrAgentOffline - } - return env, nil - } -} - -func awaitAgentUnaryResponse( - ctx context.Context, - sm *session.Manager, - requestID string, - envelope *gatewayv1.GatewayEnvelope, -) (*gatewayv1.AgentEnvelope, error) { - if !sm.IsOnline() { - return nil, session.ErrAgentOffline - } - - ch, done, cleanup, err := sm.RegisterStream(requestID) - if err != nil { - return nil, err - } - defer cleanup() - - if err := sm.SendToAgent(envelope); err != nil { - return nil, err - } - - return waitForAgentEnvelope(ctx, ch, done) -} - -func websocketErrorMessage(err error) string { - if err == nil { - return "request failed" - } - if errors.Is(err, context.DeadlineExceeded) { - return "request timed out" - } - if errors.Is(err, context.Canceled) { - return "request canceled" - } - if errors.Is(err, session.ErrAgentOffline) { - return "agent offline" - } - if errors.Is(err, session.ErrChatRunNotFound) { - return "chat stream not available" - } - return err.Error() -} - -func websocketChatEventPayload(event *gatewayv1.ChatEvent, seq int64, workdirInput ...string) map[string]any { - payload := map[string]any{ - "type": websocketChatEventType(event.GetType()), - } - if seq > 0 { - payload["seq"] = seq - } - if len(workdirInput) > 0 { - if workdir := strings.TrimSpace(workdirInput[0]); workdir != "" { - payload["workdir"] = workdir - } - } - - raw := strings.TrimSpace(event.GetData()) - if raw == "" { - raw = "{}" - } - - var decoded map[string]any - if err := json.Unmarshal([]byte(raw), &decoded); err == nil { - for key, value := range decoded { - payload[key] = value - } - } - - if conversationID := strings.TrimSpace(event.GetConversationId()); conversationID != "" { - payload["conversation_id"] = conversationID - } - - return payload -} - -func websocketChatEventType(eventType gatewayv1.ChatEvent_ChatEventType) string { - switch eventType { - case gatewayv1.ChatEvent_TOKEN: - return "token" - case gatewayv1.ChatEvent_THINKING: - return "thinking" - case gatewayv1.ChatEvent_TOOL_CALL: - return "tool_call" - case gatewayv1.ChatEvent_TOOL_RESULT: - return "tool_result" - case gatewayv1.ChatEvent_DONE: - return "done" - case gatewayv1.ChatEvent_ERROR: - return "error" - case gatewayv1.ChatEvent_TOOL_STATUS: - return "tool_status" - case gatewayv1.ChatEvent_HOSTED_SEARCH: - return "hosted_search" - default: - return "message" - } + return c.writer.write(envelope) } diff --git a/crates/agent-gateway/internal/server/websocket_chat_handlers.go b/crates/agent-gateway/internal/server/websocket_chat_handlers.go new file mode 100644 index 000000000..b050b3623 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_chat_handlers.go @@ -0,0 +1,449 @@ +package server + +import ( + "context" + "encoding/json" + "strings" + "time" + + "github.com/liveagent/agent-gateway/internal/handler" + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleChatStart(req websocketRequest) { + var body handler.ChatRequestBody + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid chat.start payload") + return + } + body.Message = strings.TrimSpace(body.Message) + body.ConversationID = strings.TrimSpace(body.ConversationID) + body.ClientRequestID = strings.TrimSpace(body.ClientRequestID) + body.ExecutionMode = handler.NormalizeExecutionMode(body.ExecutionMode) + body.Workdir = handler.NormalizeWorkdir(body.Workdir) + body.SelectedSystemTools = handler.NormalizeSelectedSystemTools(body.SelectedSystemTools) + body.UploadedFiles = handler.NormalizeChatUploadedFiles(body.UploadedFiles) + body.RuntimeControls = handler.NormalizeChatRuntimeControls(body.RuntimeControls) + selectedModel, err := handler.NormalizeChatSelectedModel(body.SelectedModel) + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + body.SelectedModel = selectedModel + if body.Message == "" && len(body.UploadedFiles) == 0 { + _ = c.writeError(req.ID, "message is required") + return + } + if !c.sm.IsOnline() { + _ = c.writeError(req.ID, "agent offline") + return + } + + snapshot, created, err := c.sm.StartChatRunWithClientRequest( + req.ID, + body.ConversationID, + body.ClientRequestID, + body.Workdir, + ) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + sourceRequestID := snapshot.RequestID + if sourceRequestID == "" { + sourceRequestID = req.ID + } + eventCh, eventDone, cleanup, snapshot, err := c.sm.SubscribeChatRun( + sourceRequestID, + snapshot.ConversationID, + 0, + ) + if err != nil { + if created { + c.sm.RemoveChatRun(sourceRequestID) + } + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + defer cleanup() + + // Register before sending so the broadcast forwarder can skip the copy that + // this same connection already receives through the recoverable chat stream. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + responseID := req.ID + c.registerActiveChat(responseID, sourceRequestID, snapshot.ConversationID, cancel) + defer c.releaseActiveChat(responseID) + + if created { + if err := c.sm.SendToAgent(&gatewayv1.GatewayEnvelope{ + RequestId: sourceRequestID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_ChatRequest{ + ChatRequest: &gatewayv1.ChatRequest{ + ConversationId: body.ConversationID, + ClientRequestId: body.ClientRequestID, + Message: body.Message, + SelectedModel: handler.ToProtoChatSelectedModel(body.SelectedModel), + RuntimeControls: handler.ToProtoChatRuntimeControls(body.RuntimeControls), + ExecutionMode: body.ExecutionMode, + Workdir: body.Workdir, + SelectedSystemTools: body.SelectedSystemTools, + UploadedFiles: handler.ToProtoChatUploadedFiles(body.UploadedFiles), + }, + }, + }); err != nil { + c.sm.RemoveChatRun(sourceRequestID) + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + } + + // Do not enforce a hard timeout for streaming chat requests. The GUI path can run + // multiple compaction rounds stably; WebUI should behave the same and only stop + // when the user cancels, the connection closes, or the agent returns done/error. + for { + select { + case <-c.done: + return + case <-ctx.Done(): + _ = c.writeError(responseID, websocketErrorMessage(ctx.Err())) + return + case <-eventDone: + return + case event, ok := <-eventCh: + if !ok { + return + } + chatEvent := event.Event + if chatEvent == nil { + continue + } + if chatEvent.GetConversationId() != "" { + body.ConversationID = strings.TrimSpace(chatEvent.GetConversationId()) + c.updateActiveChatConversationID(responseID, body.ConversationID) + } + if err := c.writeChatEvent(responseID, websocketChatEventPayload(chatEvent, event.Seq, event.Workdir)); err != nil { + c.close() + return + } + if chatEvent.GetType() == gatewayv1.ChatEvent_DONE || chatEvent.GetType() == gatewayv1.ChatEvent_ERROR { + return + } + } + } +} + +func (c *websocketConnection) handleChatResume(req websocketRequest) { + var body websocketChatResumePayload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid chat.resume payload") + return + } + body.RequestID = strings.TrimSpace(body.RequestID) + body.ConversationID = strings.TrimSpace(body.ConversationID) + if body.RequestID == "" && body.ConversationID == "" { + _ = c.writeError(req.ID, "request_id or conversation_id is required") + return + } + if body.AfterSeq < 0 { + body.AfterSeq = 0 + } + + eventCh, eventDone, cleanup, snapshot, err := c.sm.SubscribeChatRun( + body.RequestID, + body.ConversationID, + body.AfterSeq, + ) + if err != nil { + responseID := body.RequestID + if responseID == "" { + responseID = req.ID + } + _ = c.writeError(responseID, websocketErrorMessage(err)) + return + } + defer cleanup() + + responseID := snapshot.RequestID + if responseID == "" { + responseID = body.RequestID + } + if responseID == "" { + responseID = req.ID + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + c.registerActiveChat(responseID, snapshot.RequestID, snapshot.ConversationID, cancel) + defer c.releaseActiveChat(responseID) + + if snapshot.Done && snapshot.LatestSeq <= body.AfterSeq { + payload := map[string]any{ + "type": "done", + "seq": snapshot.LatestSeq, + } + if snapshot.ConversationID != "" { + payload["conversation_id"] = snapshot.ConversationID + } + if err := c.writeChatEvent(responseID, payload); err != nil { + c.close() + } + return + } + + for { + select { + case <-c.done: + return + case <-ctx.Done(): + _ = c.writeError(responseID, websocketErrorMessage(ctx.Err())) + return + case <-eventDone: + return + case event, ok := <-eventCh: + if !ok { + return + } + chatEvent := event.Event + if chatEvent == nil { + continue + } + if chatEvent.GetConversationId() != "" { + c.updateActiveChatConversationID(responseID, strings.TrimSpace(chatEvent.GetConversationId())) + } + if err := c.writeChatEvent(responseID, websocketChatEventPayload(chatEvent, event.Seq, event.Workdir)); err != nil { + c.close() + return + } + if chatEvent.GetType() == gatewayv1.ChatEvent_DONE || chatEvent.GetType() == gatewayv1.ChatEvent_ERROR { + return + } + } + } +} + +func (c *websocketConnection) handleChatAttach(req websocketRequest) { + var body websocketChatAttachPayload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid chat.attach payload") + return + } + body.ConversationID = strings.TrimSpace(body.ConversationID) + if body.ConversationID == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + if body.AfterSeq < 0 { + body.AfterSeq = 0 + } + + eventCh, eventDone, cleanup, snapshot, err := c.sm.SubscribeChatRun( + "", + body.ConversationID, + body.AfterSeq, + ) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + c.registerActiveChatAttachment(req.ID, cancel) + defer c.releaseActiveChatAttachment(req.ID) + + if snapshot.Done && snapshot.LatestSeq <= body.AfterSeq { + payload := map[string]any{ + "type": "done", + "seq": snapshot.LatestSeq, + } + if snapshot.ConversationID != "" { + payload["conversation_id"] = snapshot.ConversationID + } + if err := c.writeChatEvent(req.ID, payload); err != nil { + c.close() + } + return + } + + for { + select { + case <-c.done: + return + case <-ctx.Done(): + return + case <-eventDone: + return + case event, ok := <-eventCh: + if !ok { + return + } + chatEvent := event.Event + if chatEvent == nil { + continue + } + if err := c.writeChatEvent(req.ID, websocketChatEventPayload(chatEvent, event.Seq, event.Workdir)); err != nil { + c.close() + return + } + if chatEvent.GetType() == gatewayv1.ChatEvent_DONE || chatEvent.GetType() == gatewayv1.ChatEvent_ERROR { + return + } + } + } +} + +func (c *websocketConnection) handleChatDetach(req websocketRequest) { + var body websocketChatDetachPayload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid chat.detach payload") + return + } + targetRequestID := strings.TrimSpace(body.RequestID) + if targetRequestID == "" { + targetRequestID = req.ID + } + if targetRequestID == "" { + _ = c.writeError(req.ID, "request_id is required") + return + } + c.cancelActiveChatAttachment(targetRequestID) + _ = c.writeResponse(req.ID, map[string]any{"ok": true}) +} + +func (c *websocketConnection) handleChatCancel(req websocketRequest) { + var body handler.CancelChatRequestBody + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid chat.cancel payload") + return + } + body.ConversationID = strings.TrimSpace(body.ConversationID) + if body.ConversationID == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + if !c.sm.IsOnline() { + _ = c.writeError(req.ID, "agent offline") + return + } + + if err := c.sm.SendToAgent(&gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_CancelChat{ + CancelChat: &gatewayv1.CancelChatRequest{ + ConversationId: body.ConversationID, + }, + }, + }); err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + + c.cancelActiveChatsByConversation(body.ConversationID) + c.sm.RemoveChatRunByConversation(body.ConversationID) + _ = c.writeResponse(req.ID, map[string]any{"ok": true}) +} + +func (c *websocketConnection) registerActiveChat( + requestID string, + sourceRequestID string, + conversationID string, + cancel context.CancelFunc, +) { + c.chatTracker.registerActive(requestID, sourceRequestID, conversationID, cancel) +} + +func (c *websocketConnection) registerActiveChatAttachment(requestID string, cancel context.CancelFunc) { + c.chatTracker.registerAttachment(requestID, cancel) +} + +func (c *websocketConnection) releaseActiveChatAttachment(requestID string) { + c.chatTracker.releaseAttachment(requestID) +} + +func (c *websocketConnection) releaseAllActiveChatAttachments() []context.CancelFunc { + return c.chatTracker.releaseAllAttachments() +} + +func (c *websocketConnection) cancelActiveChatAttachment(requestID string) { + c.chatTracker.cancelAttachment(requestID) +} + +func (c *websocketConnection) hasActiveChatRequest(requestID string) bool { + return c.chatTracker.hasActiveRequest(requestID) +} + +func (c *websocketConnection) updateActiveChatConversationID(requestID string, conversationID string) { + c.chatTracker.updateConversationID(requestID, conversationID) +} + +func (c *websocketConnection) releaseActiveChat(requestID string) *websocketChatState { + return c.chatTracker.releaseActive(requestID) +} + +func (c *websocketConnection) releaseAllActiveChats() []*websocketChatState { + return c.chatTracker.releaseAllActive() +} + +func (c *websocketConnection) cancelActiveChatsByConversation(conversationID string) { + for _, chat := range c.chatTracker.cancelByConversation(conversationID) { + chat.cancel() + } +} + +func websocketChatEventPayload(event *gatewayv1.ChatEvent, seq int64, workdirInput ...string) map[string]any { + payload := map[string]any{ + "type": websocketChatEventType(event.GetType()), + } + if seq > 0 { + payload["seq"] = seq + } + if len(workdirInput) > 0 { + if workdir := strings.TrimSpace(workdirInput[0]); workdir != "" { + payload["workdir"] = workdir + } + } + + raw := strings.TrimSpace(event.GetData()) + if raw == "" { + raw = "{}" + } + + var decoded map[string]any + if err := json.Unmarshal([]byte(raw), &decoded); err == nil { + for key, value := range decoded { + payload[key] = value + } + } + + if conversationID := strings.TrimSpace(event.GetConversationId()); conversationID != "" { + payload["conversation_id"] = conversationID + } + + return payload +} + +func websocketChatEventType(eventType gatewayv1.ChatEvent_ChatEventType) string { + switch eventType { + case gatewayv1.ChatEvent_TOKEN: + return "token" + case gatewayv1.ChatEvent_THINKING: + return "thinking" + case gatewayv1.ChatEvent_TOOL_CALL: + return "tool_call" + case gatewayv1.ChatEvent_TOOL_RESULT: + return "tool_result" + case gatewayv1.ChatEvent_DONE: + return "done" + case gatewayv1.ChatEvent_ERROR: + return "error" + case gatewayv1.ChatEvent_TOOL_STATUS: + return "tool_status" + case gatewayv1.ChatEvent_HOSTED_SEARCH: + return "hosted_search" + default: + return "message" + } +} diff --git a/crates/agent-gateway/internal/server/websocket_connection_state.go b/crates/agent-gateway/internal/server/websocket_connection_state.go new file mode 100644 index 000000000..d7a719ef8 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_connection_state.go @@ -0,0 +1,263 @@ +package server + +import ( + "context" + "strings" + "sync" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +type websocketChatTracker struct { + activeMu sync.RWMutex + active map[string]*websocketChatState + recent map[string]time.Time + + attachmentsMu sync.Mutex + attachments map[string]context.CancelFunc +} + +func newWebsocketChatTracker() *websocketChatTracker { + return &websocketChatTracker{ + active: make(map[string]*websocketChatState), + recent: make(map[string]time.Time), + attachments: make(map[string]context.CancelFunc), + } +} + +func (t *websocketChatTracker) registerActive( + requestID string, + sourceRequestID string, + conversationID string, + cancel context.CancelFunc, +) { + requestID = strings.TrimSpace(requestID) + sourceRequestID = strings.TrimSpace(sourceRequestID) + t.activeMu.Lock() + defer t.activeMu.Unlock() + t.active[requestID] = &websocketChatState{ + cancel: cancel, + conversationID: strings.TrimSpace(conversationID), + sourceRequestID: sourceRequestID, + } + delete(t.recent, requestID) + delete(t.recent, sourceRequestID) +} + +func (t *websocketChatTracker) registerAttachment(requestID string, cancel context.CancelFunc) { + requestID = strings.TrimSpace(requestID) + if requestID == "" { + return + } + t.attachmentsMu.Lock() + defer t.attachmentsMu.Unlock() + if existing := t.attachments[requestID]; existing != nil { + existing() + } + t.attachments[requestID] = cancel +} + +func (t *websocketChatTracker) releaseAttachment(requestID string) { + requestID = strings.TrimSpace(requestID) + if requestID == "" { + return + } + t.attachmentsMu.Lock() + delete(t.attachments, requestID) + t.attachmentsMu.Unlock() +} + +func (t *websocketChatTracker) releaseAllAttachments() []context.CancelFunc { + t.attachmentsMu.Lock() + defer t.attachmentsMu.Unlock() + + cancels := make([]context.CancelFunc, 0, len(t.attachments)) + for requestID, cancel := range t.attachments { + delete(t.attachments, requestID) + cancels = append(cancels, cancel) + } + return cancels +} + +func (t *websocketChatTracker) cancelAttachment(requestID string) { + requestID = strings.TrimSpace(requestID) + if requestID == "" { + return + } + t.attachmentsMu.Lock() + cancel := t.attachments[requestID] + delete(t.attachments, requestID) + t.attachmentsMu.Unlock() + if cancel != nil { + cancel() + } +} + +func (t *websocketChatTracker) hasActiveRequest(requestID string) bool { + requestID = strings.TrimSpace(requestID) + if requestID == "" { + return false + } + t.activeMu.Lock() + defer t.activeMu.Unlock() + if _, ok := t.active[requestID]; ok { + return true + } + for _, chat := range t.active { + if chat.sourceRequestID == requestID { + return true + } + } + now := time.Now() + for recentRequestID, expiresAt := range t.recent { + if now.After(expiresAt) { + delete(t.recent, recentRequestID) + } + } + if expiresAt, ok := t.recent[requestID]; ok && now.Before(expiresAt) { + return true + } + return false +} + +func (t *websocketChatTracker) updateConversationID(requestID string, conversationID string) { + t.activeMu.Lock() + defer t.activeMu.Unlock() + if chat, ok := t.active[requestID]; ok { + chat.conversationID = strings.TrimSpace(conversationID) + } +} + +func (t *websocketChatTracker) releaseActive(requestID string) *websocketChatState { + t.activeMu.Lock() + defer t.activeMu.Unlock() + chat := t.active[requestID] + delete(t.active, requestID) + expiresAt := time.Now().Add(recentActiveChatRetention) + if strings.TrimSpace(requestID) != "" { + t.recent[strings.TrimSpace(requestID)] = expiresAt + } + if chat != nil && chat.sourceRequestID != "" { + t.recent[chat.sourceRequestID] = expiresAt + } + return chat +} + +func (t *websocketChatTracker) releaseAllActive() []*websocketChatState { + t.activeMu.Lock() + defer t.activeMu.Unlock() + + chats := make([]*websocketChatState, 0, len(t.active)) + expiresAt := time.Now().Add(recentActiveChatRetention) + for requestID, chat := range t.active { + delete(t.active, requestID) + if strings.TrimSpace(requestID) != "" { + t.recent[strings.TrimSpace(requestID)] = expiresAt + } + if chat != nil && chat.sourceRequestID != "" { + t.recent[chat.sourceRequestID] = expiresAt + } + chats = append(chats, chat) + } + return chats +} + +func (t *websocketChatTracker) cancelByConversation(conversationID string) []*websocketChatState { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return nil + } + + t.activeMu.Lock() + chats := make([]*websocketChatState, 0, len(t.active)) + expiresAt := time.Now().Add(recentActiveChatRetention) + for requestID, chat := range t.active { + if chat.conversationID == conversationID { + delete(t.active, requestID) + if strings.TrimSpace(requestID) != "" { + t.recent[strings.TrimSpace(requestID)] = expiresAt + } + if chat.sourceRequestID != "" { + t.recent[chat.sourceRequestID] = expiresAt + } + chats = append(chats, chat) + } + } + t.activeMu.Unlock() + + return chats +} + +type websocketTerminalInterestTracker struct { + mu sync.RWMutex + projects map[string]struct{} + sessions map[string]struct{} +} + +func newWebsocketTerminalInterestTracker() *websocketTerminalInterestTracker { + return &websocketTerminalInterestTracker{ + projects: make(map[string]struct{}), + sessions: make(map[string]struct{}), + } +} + +func (t *websocketTerminalInterestTracker) rememberProject(projectPathKey string) { + projectPathKey = strings.TrimSpace(projectPathKey) + if projectPathKey == "" { + return + } + t.mu.Lock() + t.projects[projectPathKey] = struct{}{} + t.mu.Unlock() +} + +func (t *websocketTerminalInterestTracker) rememberSession(sessionID string, projectPathKey string) { + sessionID = strings.TrimSpace(sessionID) + projectPathKey = strings.TrimSpace(projectPathKey) + if sessionID == "" && projectPathKey == "" { + return + } + t.mu.Lock() + if sessionID != "" { + t.sessions[sessionID] = struct{}{} + } + if projectPathKey != "" { + t.projects[projectPathKey] = struct{}{} + } + t.mu.Unlock() +} + +func (t *websocketTerminalInterestTracker) forget(sessionID string, projectPathKey string) { + sessionID = strings.TrimSpace(sessionID) + projectPathKey = strings.TrimSpace(projectPathKey) + t.mu.Lock() + if sessionID != "" { + delete(t.sessions, sessionID) + } + if sessionID == "" && projectPathKey != "" { + delete(t.projects, projectPathKey) + } + t.mu.Unlock() +} + +func (t *websocketTerminalInterestTracker) shouldForward(event *gatewayv1.TerminalEvent) bool { + if event == nil { + return false + } + sessionID := strings.TrimSpace(event.GetSessionId()) + projectPathKey := strings.TrimSpace(event.GetProjectPathKey()) + kind := strings.TrimSpace(event.GetKind()) + + // Terminal metadata changes are broadcast so each browser tab can keep its + // project list fresh; raw output remains gated behind explicit attachment. + if kind != "output" { + return sessionID != "" || projectPathKey != "" + } + + t.mu.RLock() + _, sessionSubscribed := t.sessions[sessionID] + t.mu.RUnlock() + + return sessionID != "" && sessionSubscribed +} diff --git a/crates/agent-gateway/internal/server/websocket_connection_state_test.go b/crates/agent-gateway/internal/server/websocket_connection_state_test.go new file mode 100644 index 000000000..86f8d9741 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_connection_state_test.go @@ -0,0 +1,68 @@ +package server + +import ( + "testing" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func TestWebsocketTerminalInterestTrackerFiltersOutputBySession(t *testing.T) { + t.Parallel() + + tracker := newWebsocketTerminalInterestTracker() + outputEvent := &gatewayv1.TerminalEvent{ + Kind: "output", + SessionId: "session-1", + ProjectPathKey: "project-1", + } + metadataEvent := &gatewayv1.TerminalEvent{ + Kind: "created", + SessionId: "session-1", + ProjectPathKey: "project-1", + } + + if tracker.shouldForward(outputEvent) { + t.Fatal("output should not forward before a session is attached") + } + if !tracker.shouldForward(metadataEvent) { + t.Fatal("metadata should forward so project/session lists stay fresh") + } + + tracker.rememberSession("session-1", "project-1") + if !tracker.shouldForward(outputEvent) { + t.Fatal("output should forward after attaching the session") + } + + tracker.forget("session-1", "project-1") + if tracker.shouldForward(outputEvent) { + t.Fatal("output should stop forwarding after detaching the session") + } +} + +func TestWebsocketChatTrackerKeepsRecentReleasedRequests(t *testing.T) { + t.Parallel() + + tracker := newWebsocketChatTracker() + called := false + tracker.registerActive("request-1", "source-1", "conversation-1", func() { + called = true + }) + + if !tracker.hasActiveRequest("request-1") { + t.Fatal("expected request id to be active") + } + if !tracker.hasActiveRequest("source-1") { + t.Fatal("expected source request id to be active") + } + + state := tracker.releaseActive("request-1") + if state == nil || state.conversationID != "conversation-1" { + t.Fatalf("released chat state = %#v", state) + } + if called { + t.Fatal("releaseActive should not cancel by itself") + } + if !tracker.hasActiveRequest("request-1") || !tracker.hasActiveRequest("source-1") { + t.Fatal("released chat should remain briefly discoverable for broadcast dedupe") + } +} diff --git a/crates/agent-gateway/internal/server/websocket_cron_handlers.go b/crates/agent-gateway/internal/server/websocket_cron_handlers.go new file mode 100644 index 000000000..ead7ae3b8 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_cron_handlers.go @@ -0,0 +1,51 @@ +package server + +import ( + "time" + + "github.com/liveagent/agent-gateway/internal/handler" + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleCronManage(req websocketRequest) { + var body handler.CronManageRequestBody + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid cron.manage payload") + return + } + if !c.sm.IsOnline() { + _ = c.writeError(req.ID, "agent offline") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_CronManage{ + CronManage: &gatewayv1.CronManageRequest{ + Action: body.Action, + TaskId: body.TaskID, + TaskJson: body.TaskJSON, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetCronManageResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{ + "action": resp.GetAction(), + "result_json": resp.GetResultJson(), + }) +} diff --git a/crates/agent-gateway/internal/server/websocket_fs_handlers.go b/crates/agent-gateway/internal/server/websocket_fs_handlers.go new file mode 100644 index 000000000..ec3d8f807 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_fs_handlers.go @@ -0,0 +1,534 @@ +package server + +import ( + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleFsRoots(req websocketRequest) { + // Payload is intentionally empty; we still decode to reject unexpected fields. + var body struct{} + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid fs.roots payload") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FsRoots{ + FsRoots: &gatewayv1.FsRootsRequest{}, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFsRootsResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + rootPayload := make([]map[string]any, 0, len(resp.GetRoots())) + for _, root := range resp.GetRoots() { + rootPayload = append(rootPayload, map[string]any{ + "id": root.GetId(), + "path": root.GetPath(), + "kind": root.GetKind(), + "label": root.GetLabel(), + }) + } + + _ = c.writeResponse(req.ID, map[string]any{ + "roots": rootPayload, + }) +} + +func (c *websocketConnection) handleFsListDirs(req websocketRequest) { + type payload struct { + Path string `json:"path"` + MaxResults *int `json:"max_results"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid fs.list_dirs payload") + return + } + + dir := strings.TrimSpace(body.Path) + if dir == "" { + _ = c.writeError(req.ID, "path is required") + return + } + + maxResults, err := websocketOptionalUint32(body.MaxResults, "max_results") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FsListDirs{ + FsListDirs: &gatewayv1.FsListDirsRequest{ + Path: dir, + MaxResults: maxResults, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFsListDirsResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + entryPayload := make([]map[string]any, 0, len(resp.GetEntries())) + for _, entry := range resp.GetEntries() { + entryPayload = append(entryPayload, map[string]any{ + "path": entry.GetPath(), + "name": entry.GetName(), + }) + } + + _ = c.writeResponse(req.ID, map[string]any{ + "path": strings.TrimSpace(resp.GetPath()), + "entries": entryPayload, + "truncated": resp.GetTruncated(), + }) +} + +func (c *websocketConnection) handleFsCreateProjectFolder(req websocketRequest) { + type payload struct { + Parent string `json:"parent"` + Name string `json:"name"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid fs.create_project_folder payload") + return + } + + parent := strings.TrimSpace(body.Parent) + name := strings.TrimSpace(body.Name) + if parent == "" { + _ = c.writeError(req.ID, "parent is required") + return + } + if name == "" { + _ = c.writeError(req.ID, "name is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FsCreateProjectFolder{ + FsCreateProjectFolder: &gatewayv1.FsCreateProjectFolderRequest{ + Parent: parent, + Name: name, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFsCreateProjectFolderResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{ + "path": strings.TrimSpace(resp.GetPath()), + }) +} + +func (c *websocketConnection) handleFsList(req websocketRequest) { + type payload struct { + Workdir string `json:"workdir"` + Path string `json:"path"` + Depth *int `json:"depth"` + Offset *int `json:"offset"` + MaxResults *int `json:"max_results"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid fs.list payload") + return + } + + workdir := strings.TrimSpace(body.Workdir) + if workdir == "" { + _ = c.writeError(req.ID, "workdir is required") + return + } + + depth, err := websocketOptionalUint32(body.Depth, "depth") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + offset, err := websocketOptionalUint32(body.Offset, "offset") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + maxResults, err := websocketOptionalUint32(body.MaxResults, "max_results") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FsList{ + FsList: &gatewayv1.FsListRequest{ + Workdir: workdir, + Path: strings.TrimSpace(body.Path), + Depth: depth, + Offset: offset, + MaxResults: maxResults, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFsListResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, websocketFsListResponsePayload(resp)) +} + +func (c *websocketConnection) handleFsWriteText(req websocketRequest) { + type payload struct { + Workdir string `json:"workdir"` + Path string `json:"path"` + Content string `json:"content"` + Mode string `json:"mode"` + ExpectedMtimeMs *uint64 `json:"expected_mtime_ms"` + ExpectedContentHash *string `json:"expected_content_hash"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid fs.write_text payload") + return + } + + workdir := strings.TrimSpace(body.Workdir) + path := strings.TrimSpace(body.Path) + if workdir == "" { + _ = c.writeError(req.ID, "workdir is required") + return + } + if path == "" { + _ = c.writeError(req.ID, "path is required") + return + } + mode := strings.TrimSpace(body.Mode) + if mode == "" { + mode = "rewrite" + } + expectedHash := "" + hasExpectedHash := false + if body.ExpectedContentHash != nil { + expectedHash = strings.TrimSpace(*body.ExpectedContentHash) + hasExpectedHash = true + } + expectedMtime := uint64(0) + hasExpectedMtime := false + if body.ExpectedMtimeMs != nil { + expectedMtime = *body.ExpectedMtimeMs + hasExpectedMtime = true + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FsWriteText{ + FsWriteText: &gatewayv1.FsWriteTextRequest{ + Workdir: workdir, + Path: path, + Content: body.Content, + Mode: mode, + ExpectedMtimeMs: expectedMtime, + ExpectedContentHash: expectedHash, + HasExpectedMtimeMs: hasExpectedMtime, + HasExpectedContentHash: hasExpectedHash, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFsWriteTextResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, websocketFsWriteTextResponsePayload(resp)) +} + +func (c *websocketConnection) handleFsCreateDir(req websocketRequest) { + type payload struct { + Workdir string `json:"workdir"` + Path string `json:"path"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid fs.create_dir payload") + return + } + + workdir := strings.TrimSpace(body.Workdir) + path := strings.TrimSpace(body.Path) + if workdir == "" { + _ = c.writeError(req.ID, "workdir is required") + return + } + if path == "" { + _ = c.writeError(req.ID, "path is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FsCreateDir{ + FsCreateDir: &gatewayv1.FsCreateDirRequest{ + Workdir: workdir, + Path: path, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFsCreateDirResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, websocketFsCreateDirResponsePayload(resp)) +} + +func (c *websocketConnection) handleFsRename(req websocketRequest) { + type payload struct { + Workdir string `json:"workdir"` + FromPath string `json:"from_path"` + ToPath string `json:"to_path"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid fs.rename payload") + return + } + + workdir := strings.TrimSpace(body.Workdir) + fromPath := strings.TrimSpace(body.FromPath) + toPath := strings.TrimSpace(body.ToPath) + if workdir == "" { + _ = c.writeError(req.ID, "workdir is required") + return + } + if fromPath == "" { + _ = c.writeError(req.ID, "from_path is required") + return + } + if toPath == "" { + _ = c.writeError(req.ID, "to_path is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FsRename{ + FsRename: &gatewayv1.FsRenameRequest{ + Workdir: workdir, + FromPath: fromPath, + ToPath: toPath, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFsRenameResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, websocketFsRenameResponsePayload(resp)) +} + +func (c *websocketConnection) handleFsDelete(req websocketRequest) { + type payload struct { + Workdir string `json:"workdir"` + Path string `json:"path"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid fs.delete payload") + return + } + + workdir := strings.TrimSpace(body.Workdir) + path := strings.TrimSpace(body.Path) + if workdir == "" { + _ = c.writeError(req.ID, "workdir is required") + return + } + if path == "" { + _ = c.writeError(req.ID, "path is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FsDelete{ + FsDelete: &gatewayv1.FsDeleteRequest{ + Workdir: workdir, + Path: path, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFsDeleteResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, websocketFsDeleteResponsePayload(resp)) +} + +func websocketFsListResponsePayload(resp *gatewayv1.FsListResponse) map[string]any { + entryPayload := make([]map[string]any, 0, len(resp.GetEntries())) + for _, entry := range resp.GetEntries() { + entryPayload = append(entryPayload, map[string]any{ + "path": entry.GetPath(), + "kind": entry.GetKind(), + }) + } + + var path any + if resp.GetHasPath() { + path = resp.GetPath() + } + + return map[string]any{ + "path": path, + "depth": resp.GetDepth(), + "offset": resp.GetOffset(), + "maxResults": resp.GetMaxResults(), + "total": resp.GetTotal(), + "hasMore": resp.GetHasMore(), + "entries": entryPayload, + } +} + +func websocketFsWriteTextResponsePayload(resp *gatewayv1.FsWriteTextResponse) map[string]any { + return map[string]any{ + "path": resp.GetPath(), + "mode": resp.GetMode(), + "existedBefore": resp.GetExistedBefore(), + "bytesWritten": resp.GetBytesWritten(), + "mtimeMs": resp.GetMtimeMs(), + "contentHash": resp.GetContentHash(), + "totalLines": resp.GetTotalLines(), + } +} + +func websocketFsCreateDirResponsePayload(resp *gatewayv1.FsCreateDirResponse) map[string]any { + return map[string]any{ + "path": resp.GetPath(), + "kind": resp.GetKind(), + } +} + +func websocketFsRenameResponsePayload(resp *gatewayv1.FsRenameResponse) map[string]any { + return map[string]any{ + "fromPath": resp.GetFromPath(), + "path": resp.GetPath(), + "kind": resp.GetKind(), + } +} + +func websocketFsDeleteResponsePayload(resp *gatewayv1.FsDeleteResponse) map[string]any { + return map[string]any{ + "path": resp.GetPath(), + "kind": resp.GetKind(), + } +} diff --git a/crates/agent-gateway/internal/server/websocket_git_handlers.go b/crates/agent-gateway/internal/server/websocket_git_handlers.go new file mode 100644 index 000000000..336c1af23 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_git_handlers.go @@ -0,0 +1,69 @@ +package server + +import ( + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func gitActionFromRequestType(requestType string) string { + return strings.TrimPrefix(strings.TrimSpace(requestType), "git.") +} + +func gitActionIsWrite(action string) bool { + switch action { + case "init", "switch_branch", "create_branch", "stage", "stage_all", "unstage", "unstage_all", "discard", "discard_all", "add_to_gitignore", "commit", "fetch", "pull", "set_remote", "push": + return true + default: + return false + } +} + +func (c *websocketConnection) handleGitRequest(req websocketRequest) { + action := gitActionFromRequestType(req.Type) + if gitActionIsWrite(action) && !c.sm.WebGitEnabled() { + _ = c.writeError(req.ID, "web git is disabled in desktop Remote settings") + return + } + + var body websocketGitRequestPayload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid "+req.Type+" payload") + return + } + argsJSON := strings.TrimSpace(string(body.Args)) + if argsJSON == "" { + argsJSON = "{}" + } + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_GitRequest{ + GitRequest: &gatewayv1.GitRequest{ + Action: action, + Workdir: strings.TrimSpace(body.Workdir), + ArgsJson: argsJSON, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + resp := response.GetGitResponse() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + payload, err := websocketGitResultPayload(resp.GetResultJson()) + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + _ = c.writeResponse(req.ID, payload) +} diff --git a/crates/agent-gateway/internal/server/websocket_history_handlers.go b/crates/agent-gateway/internal/server/websocket_history_handlers.go new file mode 100644 index 000000000..9ad675080 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_history_handlers.go @@ -0,0 +1,523 @@ +package server + +import ( + "encoding/json" + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleHistoryList(req websocketRequest) { + type payload struct { + Page int `json:"page"` + PageSize int `json:"page_size"` + Cwd string `json:"cwd"` + CwdEmpty bool `json:"cwd_empty"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.list payload") + return + } + page := body.Page + if page <= 0 { + page = defaultHistoryListPage + } + pageSize := body.PageSize + if pageSize <= 0 { + pageSize = defaultHistoryListPageSize + } else if pageSize > maxHistoryListLimit { + pageSize = maxHistoryListLimit + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryList{ + HistoryList: &gatewayv1.HistoryListRequest{ + Page: int32(page), + PageSize: int32(pageSize), + Cwd: strings.TrimSpace(body.Cwd), + CwdEmpty: body.CwdEmpty, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetHistoryListResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + conversations := make([]map[string]any, 0, len(resp.GetConversations())) + for _, conversation := range resp.GetConversations() { + conversations = append(conversations, websocketConversationSummaryPayload(conversation)) + } + + _ = c.writeResponse(req.ID, map[string]any{ + "conversations": conversations, + "total_count": resp.GetTotalCount(), + "running_conversation_ids": c.sm.ActiveChatRunConversationIDs(), + "running_conversations": websocketActiveChatRunSummariesPayload(c.sm.ActiveChatRunSummaries()), + }) +} + +func (c *websocketConnection) handleHistoryWorkdirs(req websocketRequest) { + var body struct{} + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.workdirs payload") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryWorkdirs{ + HistoryWorkdirs: &gatewayv1.HistoryWorkdirsRequest{}, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetHistoryWorkdirsResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + workdirs := make([]map[string]any, 0, len(resp.GetWorkdirs())) + for _, workdir := range resp.GetWorkdirs() { + workdirs = append(workdirs, map[string]any{ + "path": workdir.GetPath(), + "conversation_count": workdir.GetConversationCount(), + "updated_at": workdir.GetUpdatedAt(), + }) + } + + _ = c.writeResponse(req.ID, map[string]any{ + "workdirs": workdirs, + }) +} + +func (c *websocketConnection) handleHistorySharedList(req websocketRequest) { + type payload struct { + Page int `json:"page"` + PageSize int `json:"page_size"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.shared_list payload") + return + } + page := body.Page + if page <= 0 { + page = defaultHistoryListPage + } + pageSize := body.PageSize + if pageSize <= 0 { + pageSize = defaultHistoryListPageSize + } else if pageSize > maxHistoryListLimit { + pageSize = maxHistoryListLimit + } + + argsJSON, err := json.Marshal(map[string]any{ + "page": page, + "page_size": pageSize, + }) + if err != nil { + _ = c.writeError(req.ID, "invalid history.shared_list payload") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_MemoryManage{ + MemoryManage: &gatewayv1.MemoryManageRequest{ + Command: "history_shared_list", + ArgsJson: string(argsJSON), + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetMemoryManageResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + var result struct { + Conversations []map[string]any `json:"conversations"` + TotalCount int `json:"total_count"` + } + if err := json.Unmarshal([]byte(resp.GetResultJson()), &result); err != nil { + _ = c.writeError(req.ID, "invalid history.shared_list response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{ + "conversations": result.Conversations, + "total_count": result.TotalCount, + }) +} + +func (c *websocketConnection) handleHistoryGet(req websocketRequest) { + type payload struct { + ConversationID string `json:"conversation_id"` + MaxMessages int32 `json:"max_messages"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.get payload") + return + } + if strings.TrimSpace(body.ConversationID) == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryGet{ + HistoryGet: &gatewayv1.HistoryGetRequest{ + ConversationId: body.ConversationID, + MaxMessages: body.MaxMessages, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetHistoryGetResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{ + "conversation_id": resp.GetConversationId(), + "messages_json": resp.GetMessagesJson(), + "total_message_count": resp.GetTotalMessageCount(), + "returned_message_count": resp.GetReturnedMessageCount(), + "has_more": resp.GetHasMore(), + "conversation": websocketConversationSummaryPayload(resp.GetConversation()), + }) +} + +func (c *websocketConnection) handleHistoryRename(req websocketRequest) { + type payload struct { + ConversationID string `json:"conversation_id"` + Title string `json:"title"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.rename payload") + return + } + if strings.TrimSpace(body.ConversationID) == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + if strings.TrimSpace(body.Title) == "" { + _ = c.writeError(req.ID, "title is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryRename{ + HistoryRename: &gatewayv1.HistoryRenameRequest{ + ConversationId: body.ConversationID, + Title: body.Title, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetHistoryRenameResp() + if resp == nil || resp.GetConversation() == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + conversation := resp.GetConversation() + _ = c.writeResponse(req.ID, websocketConversationSummaryPayload(conversation)) +} + +func (c *websocketConnection) handleHistoryPin(req websocketRequest) { + type payload struct { + ConversationID string `json:"conversation_id"` + IsPinned bool `json:"is_pinned"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.pin payload") + return + } + if strings.TrimSpace(body.ConversationID) == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryPin{ + HistoryPin: &gatewayv1.HistoryPinRequest{ + ConversationId: body.ConversationID, + IsPinned: body.IsPinned, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetHistoryPinResp() + if resp == nil || resp.GetConversation() == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, websocketConversationSummaryPayload(resp.GetConversation())) +} + +func (c *websocketConnection) handleHistoryShareGet(req websocketRequest) { + type payload struct { + ConversationID string `json:"conversation_id"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.share.get payload") + return + } + if strings.TrimSpace(body.ConversationID) == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryShareGet{ + HistoryShareGet: &gatewayv1.HistoryShareGetRequest{ + ConversationId: body.ConversationID, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetHistoryShareGetResp() + if resp == nil || resp.GetShare() == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, websocketHistoryShareStatusPayload(resp.GetShare())) +} + +func (c *websocketConnection) handleHistoryShareSet(req websocketRequest) { + type payload struct { + ConversationID string `json:"conversation_id"` + Enabled bool `json:"enabled"` + RedactToolContent *bool `json:"redact_tool_content,omitempty"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.share.set payload") + return + } + if strings.TrimSpace(body.ConversationID) == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryShareSet{ + HistoryShareSet: &gatewayv1.HistoryShareSetRequest{ + ConversationId: body.ConversationID, + Enabled: body.Enabled, + RedactToolContent: body.RedactToolContent, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetHistoryShareSetResp() + if resp == nil || resp.GetShare() == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, websocketHistoryShareStatusPayload(resp.GetShare())) +} + +func (c *websocketConnection) handleHistoryDelete(req websocketRequest) { + type payload struct { + ConversationID string `json:"conversation_id"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.delete payload") + return + } + if strings.TrimSpace(body.ConversationID) == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryDelete{ + HistoryDelete: &gatewayv1.HistoryDeleteRequest{ + ConversationId: body.ConversationID, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + if response.GetHistoryDeleteResp() == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{"ok": true}) +} + +func (c *websocketConnection) handleHistoryTruncate(req websocketRequest) { + type payload struct { + ConversationID string `json:"conversation_id"` + SegmentIndex int `json:"segment_index"` + MessageIndex int `json:"message_index"` + OmitMessagesJSON bool `json:"omit_messages_json"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid history.truncate payload") + return + } + if strings.TrimSpace(body.ConversationID) == "" { + _ = c.writeError(req.ID, "conversation_id is required") + return + } + if body.SegmentIndex < 0 { + _ = c.writeError(req.ID, "segment_index must be >= 0") + return + } + if body.MessageIndex < 0 { + _ = c.writeError(req.ID, "message_index must be >= 0") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_HistoryTruncate{ + HistoryTruncate: &gatewayv1.HistoryTruncateRequest{ + ConversationId: body.ConversationID, + SegmentIndex: int32(body.SegmentIndex), + MessageIndex: int32(body.MessageIndex), + OmitMessagesJson: body.OmitMessagesJSON, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetHistoryTruncateResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + payloadMap := map[string]any{ + "conversation_id": resp.GetConversationId(), + "messages_json": resp.GetMessagesJson(), + } + if conversation := resp.GetConversation(); conversation != nil { + payloadMap["conversation"] = websocketConversationSummaryPayload(conversation) + } + + _ = c.writeResponse(req.ID, payloadMap) +} diff --git a/crates/agent-gateway/internal/server/websocket_memory_handlers.go b/crates/agent-gateway/internal/server/websocket_memory_handlers.go new file mode 100644 index 000000000..92c18ecc5 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_memory_handlers.go @@ -0,0 +1,73 @@ +package server + +import ( + "encoding/json" + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleMemoryManage(req websocketRequest) { + type payload struct { + Command string `json:"command"` + Args json.RawMessage `json:"args"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid memory.manage payload") + return + } + + command := strings.TrimSpace(body.Command) + if command == "" { + _ = c.writeError(req.ID, "command is required") + return + } + if !strings.HasPrefix(command, "memory_") { + _ = c.writeError(req.ID, "unsupported memory command") + return + } + + argsJSON := strings.TrimSpace(string(body.Args)) + if argsJSON == "" { + argsJSON = "{}" + } + if !json.Valid([]byte(argsJSON)) { + _ = c.writeError(req.ID, "memory args must be valid JSON") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_MemoryManage{ + MemoryManage: &gatewayv1.MemoryManageRequest{ + Command: command, + ArgsJson: argsJSON, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetMemoryManageResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + payloadValue, err := websocketMemoryResultPayload(resp.GetResultJson()) + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + _ = c.writeResponse(req.ID, payloadValue) +} diff --git a/crates/agent-gateway/internal/server/websocket_payloads.go b/crates/agent-gateway/internal/server/websocket_payloads.go new file mode 100644 index 000000000..b54b98bb2 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_payloads.go @@ -0,0 +1,252 @@ +package server + +import ( + "encoding/json" + "errors" + "strings" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" + "github.com/liveagent/agent-gateway/internal/session" +) + +func websocketConversationSummaryPayload(conversation *gatewayv1.ConversationSummary) map[string]any { + if conversation == nil { + return nil + } + + return map[string]any{ + "id": conversation.GetId(), + "title": conversation.GetTitle(), + "created_at": conversation.GetCreatedAt(), + "updated_at": conversation.GetUpdatedAt(), + "message_count": conversation.GetMessageCount(), + "provider_id": conversation.GetProviderId(), + "model": conversation.GetModel(), + "session_id": conversation.GetSessionId(), + "cwd": conversation.GetCwd(), + "is_pinned": conversation.GetIsPinned(), + "pinned_at": conversation.GetPinnedAt(), + "is_shared": conversation.GetIsShared(), + } +} + +func websocketActiveChatRunSummariesPayload(summaries []session.ActiveChatRunSummary) []map[string]any { + payload := make([]map[string]any, 0, len(summaries)) + for _, summary := range summaries { + conversationID := strings.TrimSpace(summary.ConversationID) + if conversationID == "" { + continue + } + payload = append(payload, map[string]any{ + "conversation_id": conversationID, + "cwd": strings.TrimSpace(summary.Workdir), + "updated_at": summary.UpdatedAt, + }) + } + return payload +} + +func websocketHistoryShareStatusPayload(share *gatewayv1.HistoryShareStatus) map[string]any { + if share == nil { + return nil + } + + return map[string]any{ + "conversation_id": share.GetConversationId(), + "enabled": share.GetEnabled(), + "token": share.GetToken(), + "created_at": share.GetCreatedAt(), + "updated_at": share.GetUpdatedAt(), + "redact_tool_content": share.GetRedactToolContent(), + } +} + +func websocketHistorySyncPayload(event *gatewayv1.HistorySyncEvent) map[string]any { + payload := map[string]any{ + "kind": strings.TrimSpace(event.GetKind()), + "conversation_id": strings.TrimSpace(event.GetConversationId()), + } + + if conversation := event.GetConversation(); conversation != nil { + payload["conversation"] = websocketConversationSummaryPayload(conversation) + } + + return payload +} + +func websocketSettingsSyncPayload(event *gatewayv1.SettingsSyncEvent) (map[string]any, error) { + return websocketSettingsJSONPayload(event.GetSettingsJson()) +} + +func websocketSettingsJSONPayload(raw string) (map[string]any, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return map[string]any{}, nil + } + + var payload map[string]any + if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { + return nil, errors.New("gateway settings payload is not valid JSON") + } + if payload == nil { + return map[string]any{}, nil + } + return payload, nil +} + +func websocketTerminalSessionPayload(session *gatewayv1.TerminalSession) map[string]any { + if session == nil { + return nil + } + payload := map[string]any{ + "id": strings.TrimSpace(session.GetId()), + "project_path_key": strings.TrimSpace(session.GetProjectPathKey()), + "cwd": strings.TrimSpace(session.GetCwd()), + "shell": strings.TrimSpace(session.GetShell()), + "title": strings.TrimSpace(session.GetTitle()), + "pid": session.GetPid(), + "cols": session.GetCols(), + "rows": session.GetRows(), + "created_at": session.GetCreatedAt(), + "updated_at": session.GetUpdatedAt(), + "finished_at": session.GetFinishedAt(), + "exit_code": session.GetExitCode(), + "running": session.GetRunning(), + } + if session.GetPid() == 0 { + payload["pid"] = nil + } + if session.GetFinishedAt() == 0 { + payload["finished_at"] = nil + } + return payload +} + +func websocketTerminalShellOptionPayload(option *gatewayv1.TerminalShellOption) map[string]any { + if option == nil { + return nil + } + return map[string]any{ + "id": strings.TrimSpace(option.GetId()), + "label": strings.TrimSpace(option.GetLabel()), + "command": strings.TrimSpace(option.GetCommand()), + } +} + +func websocketTerminalResponsePayload(resp *gatewayv1.TerminalResponse) map[string]any { + sessions := make([]map[string]any, 0, len(resp.GetSessions())) + for _, session := range resp.GetSessions() { + if payload := websocketTerminalSessionPayload(session); payload != nil { + sessions = append(sessions, payload) + } + } + shellOptions := make([]map[string]any, 0, len(resp.GetShellOptions())) + for _, option := range resp.GetShellOptions() { + if payload := websocketTerminalShellOptionPayload(option); payload != nil { + shellOptions = append(shellOptions, payload) + } + } + payload := map[string]any{ + "action": strings.TrimSpace(resp.GetAction()), + "sessions": sessions, + "output": resp.GetOutput(), + "truncated": resp.GetTruncated(), + "shell_options": shellOptions, + "default_shell": resp.GetDefaultShell(), + } + if resp.GetOutputStartOffset() != 0 || resp.GetOutputEndOffset() != 0 || resp.GetOutput() != "" { + payload["output_start_offset"] = resp.GetOutputStartOffset() + payload["output_end_offset"] = resp.GetOutputEndOffset() + } + if session := websocketTerminalSessionPayload(resp.GetSession()); session != nil { + payload["session"] = session + } + return payload +} + +func websocketTerminalEventPayload(event *gatewayv1.TerminalEvent) map[string]any { + payload := map[string]any{ + "kind": strings.TrimSpace(event.GetKind()), + "session_id": strings.TrimSpace(event.GetSessionId()), + "project_path_key": strings.TrimSpace(event.GetProjectPathKey()), + "data": event.GetData(), + } + if event.GetOutputStartOffset() != 0 || event.GetOutputEndOffset() != 0 || event.GetData() != "" { + payload["output_start_offset"] = event.GetOutputStartOffset() + payload["output_end_offset"] = event.GetOutputEndOffset() + } + if session := websocketTerminalSessionPayload(event.GetSession()); session != nil { + payload["session"] = session + } + return payload +} + +func websocketMemoryResultPayload(raw string) (any, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return map[string]any{}, nil + } + + var payload any + if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { + return nil, errors.New("gateway memory response is not valid JSON") + } + if payload == nil { + return map[string]any{}, nil + } + return payload, nil +} + +func websocketGitResultPayload(raw string) (any, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return map[string]any{}, nil + } + var payload any + if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { + return nil, errors.New("gateway git response is not valid JSON") + } + if payload == nil { + return map[string]any{}, nil + } + return payload, nil +} + +func websocketRawPayloadJSON(raw json.RawMessage) (string, error) { + trimmed := strings.TrimSpace(string(raw)) + if trimmed == "" { + return "{}", nil + } + + var payload map[string]any + if err := json.Unmarshal([]byte(trimmed), &payload); err != nil { + return "", errors.New("invalid settings.update payload") + } + if payload == nil { + return "{}", nil + } + + normalized, err := json.Marshal(payload) + if err != nil { + return "", errors.New("invalid settings.update payload") + } + return string(normalized), nil +} + +func nullableTrimmedString(value string) any { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return nil + } + return trimmed +} + +func websocketOptionalUint32(value *int, field string) (uint32, error) { + if value == nil { + return 0, nil + } + if *value < 0 { + return 0, errors.New(field + " must be >= 0") + } + return uint32(*value), nil +} diff --git a/crates/agent-gateway/internal/server/websocket_provider_handlers.go b/crates/agent-gateway/internal/server/websocket_provider_handlers.go new file mode 100644 index 000000000..70907022f --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_provider_handlers.go @@ -0,0 +1,77 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "strings" + "time" + + "github.com/liveagent/agent-gateway/internal/handler" + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleProviderList(req websocketRequest) { + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_ProviderList{ + ProviderList: &gatewayv1.ProviderListRequest{}, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetProviderListResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + var payload any + raw := strings.TrimSpace(resp.GetProvidersJson()) + if raw == "" { + payload = []any{} + } else if err := json.Unmarshal([]byte(raw), &payload); err != nil { + _ = c.writeError(req.ID, "provider list response is not valid JSON") + return + } + + _ = c.writeResponse(req.ID, payload) +} + +func (c *websocketConnection) handleProviderModels(req websocketRequest) { + var body handler.ProviderModelsRequestBody + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid provider.models payload") + return + } + + ctx, cancel := context.WithTimeout(context.Background(), c.cfg.RequestTimeout) + defer cancel() + + result, err := handler.FetchProviderModels(ctx, body) + if err != nil { + var statusErr *handler.HTTPStatusError + if errors.As(err, &statusErr) { + _ = c.writeError(req.ID, statusErr.Message) + return + } + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + + var payload any + if err := json.Unmarshal(result.Body, &payload); err != nil { + _ = c.writeError(req.ID, "provider model response is not valid JSON") + return + } + + _ = c.writeResponse(req.ID, payload) +} diff --git a/crates/agent-gateway/internal/server/websocket_roundtrip.go b/crates/agent-gateway/internal/server/websocket_roundtrip.go new file mode 100644 index 000000000..3d7836aa6 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_roundtrip.go @@ -0,0 +1,80 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "strings" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" + "github.com/liveagent/agent-gateway/internal/session" +) + +func decodeWebSocketPayload(raw json.RawMessage, target any) error { + if len(raw) == 0 { + return json.Unmarshal([]byte("{}"), target) + } + decoder := json.NewDecoder(strings.NewReader(string(raw))) + decoder.DisallowUnknownFields() + return decoder.Decode(target) +} + +func waitForAgentEnvelope( + ctx context.Context, + ch <-chan *gatewayv1.AgentEnvelope, + done <-chan struct{}, +) (*gatewayv1.AgentEnvelope, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-done: + return nil, session.ErrAgentOffline + case env, ok := <-ch: + if !ok { + return nil, session.ErrAgentOffline + } + return env, nil + } +} + +func awaitAgentUnaryResponse( + ctx context.Context, + sm *session.Manager, + requestID string, + envelope *gatewayv1.GatewayEnvelope, +) (*gatewayv1.AgentEnvelope, error) { + if !sm.IsOnline() { + return nil, session.ErrAgentOffline + } + + ch, done, cleanup, err := sm.RegisterStream(requestID) + if err != nil { + return nil, err + } + defer cleanup() + + if err := sm.SendToAgent(envelope); err != nil { + return nil, err + } + + return waitForAgentEnvelope(ctx, ch, done) +} + +func websocketErrorMessage(err error) string { + if err == nil { + return "request failed" + } + if errors.Is(err, context.DeadlineExceeded) { + return "request timed out" + } + if errors.Is(err, context.Canceled) { + return "request canceled" + } + if errors.Is(err, session.ErrAgentOffline) { + return "agent offline" + } + if errors.Is(err, session.ErrChatRunNotFound) { + return "chat stream not available" + } + return err.Error() +} diff --git a/crates/agent-gateway/internal/server/websocket_routes.go b/crates/agent-gateway/internal/server/websocket_routes.go new file mode 100644 index 000000000..1e81bf2f7 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_routes.go @@ -0,0 +1,76 @@ +package server + +type websocketRequestHandler func(*websocketConnection, websocketRequest) + +var websocketRequestHandlers = map[string]websocketRequestHandler{ + "status.get": func(c *websocketConnection, req websocketRequest) { + _ = c.writeResponse(req.ID, c.sm.Status()) + }, + "fs.roots": (*websocketConnection).handleFsRoots, + "fs.list_dirs": (*websocketConnection).handleFsListDirs, + "fs.create_project_folder": (*websocketConnection).handleFsCreateProjectFolder, + "fs.list": (*websocketConnection).handleFsList, + "fs.write_text": (*websocketConnection).handleFsWriteText, + "fs.create_dir": (*websocketConnection).handleFsCreateDir, + "fs.rename": (*websocketConnection).handleFsRename, + "fs.delete": (*websocketConnection).handleFsDelete, + "history.list": (*websocketConnection).handleHistoryList, + "history.workdirs": (*websocketConnection).handleHistoryWorkdirs, + "history.shared_list": (*websocketConnection).handleHistorySharedList, + "history.get": (*websocketConnection).handleHistoryGet, + "history.rename": (*websocketConnection).handleHistoryRename, + "history.pin": (*websocketConnection).handleHistoryPin, + "history.share.get": (*websocketConnection).handleHistoryShareGet, + "history.share.set": (*websocketConnection).handleHistoryShareSet, + "history.delete": (*websocketConnection).handleHistoryDelete, + "history.truncate": (*websocketConnection).handleHistoryTruncate, + "providers.list": (*websocketConnection).handleProviderList, + "settings.get": (*websocketConnection).handleSettingsGet, + "settings.update": (*websocketConnection).handleSettingsUpdate, + "skills.list": (*websocketConnection).handleSkillFilesList, + "mentions.list": (*websocketConnection).handleFileMentionList, + "skills.read-metadata": (*websocketConnection).handleSkillMetadataRead, + "skills.read-text": (*websocketConnection).handleSkillTextRead, + "skills.manage": (*websocketConnection).handleSkillManage, + "chat.start": (*websocketConnection).handleChatStart, + "chat.resume": (*websocketConnection).handleChatResume, + "chat.attach": (*websocketConnection).handleChatAttach, + "chat.detach": (*websocketConnection).handleChatDetach, + "chat.cancel": (*websocketConnection).handleChatCancel, + "files.preview": (*websocketConnection).handleUploadedImagePreview, + "memory.manage": (*websocketConnection).handleMemoryManage, + "terminal.shell_options": (*websocketConnection).handleTerminalRequest, + "terminal.list": (*websocketConnection).handleTerminalRequest, + "terminal.create": (*websocketConnection).handleTerminalRequest, + "terminal.attach": (*websocketConnection).handleTerminalRequest, + "terminal.input": (*websocketConnection).handleTerminalRequest, + "terminal.resize": (*websocketConnection).handleTerminalRequest, + "terminal.rename": (*websocketConnection).handleTerminalRequest, + "terminal.close": (*websocketConnection).handleTerminalRequest, + "terminal.close_project": (*websocketConnection).handleTerminalRequest, + "terminal.detach": (*websocketConnection).handleTerminalDetach, + "git.status": (*websocketConnection).handleGitRequest, + "git.branches": (*websocketConnection).handleGitRequest, + "git.init": (*websocketConnection).handleGitRequest, + "git.switch_branch": (*websocketConnection).handleGitRequest, + "git.create_branch": (*websocketConnection).handleGitRequest, + "git.diff": (*websocketConnection).handleGitRequest, + "git.log": (*websocketConnection).handleGitRequest, + "git.commit_details": (*websocketConnection).handleGitRequest, + "git.compare_commit_with_remote": (*websocketConnection).handleGitRequest, + "git.commit_diff": (*websocketConnection).handleGitRequest, + "git.stage": (*websocketConnection).handleGitRequest, + "git.stage_all": (*websocketConnection).handleGitRequest, + "git.unstage": (*websocketConnection).handleGitRequest, + "git.unstage_all": (*websocketConnection).handleGitRequest, + "git.discard": (*websocketConnection).handleGitRequest, + "git.discard_all": (*websocketConnection).handleGitRequest, + "git.add_to_gitignore": (*websocketConnection).handleGitRequest, + "git.commit": (*websocketConnection).handleGitRequest, + "git.fetch": (*websocketConnection).handleGitRequest, + "git.pull": (*websocketConnection).handleGitRequest, + "git.set_remote": (*websocketConnection).handleGitRequest, + "git.push": (*websocketConnection).handleGitRequest, + "cron.manage": (*websocketConnection).handleCronManage, + "provider.models": (*websocketConnection).handleProviderModels, +} diff --git a/crates/agent-gateway/internal/server/websocket_routes_test.go b/crates/agent-gateway/internal/server/websocket_routes_test.go new file mode 100644 index 000000000..d8685c23a --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_routes_test.go @@ -0,0 +1,103 @@ +package server + +import "testing" + +func TestWebsocketRequestHandlersCoverKnownProtocolTypes(t *testing.T) { + t.Parallel() + + expectedTypes := []string{ + "status.get", + "fs.roots", + "fs.list_dirs", + "fs.create_project_folder", + "fs.list", + "fs.write_text", + "fs.create_dir", + "fs.rename", + "fs.delete", + "history.list", + "history.workdirs", + "history.shared_list", + "history.get", + "history.rename", + "history.pin", + "history.share.get", + "history.share.set", + "history.delete", + "history.truncate", + "providers.list", + "settings.get", + "settings.update", + "skills.list", + "mentions.list", + "skills.read-metadata", + "skills.read-text", + "skills.manage", + "chat.start", + "chat.resume", + "chat.attach", + "chat.detach", + "chat.cancel", + "files.preview", + "memory.manage", + "terminal.shell_options", + "terminal.list", + "terminal.create", + "terminal.attach", + "terminal.input", + "terminal.resize", + "terminal.rename", + "terminal.close", + "terminal.close_project", + "terminal.detach", + "git.status", + "git.branches", + "git.init", + "git.switch_branch", + "git.create_branch", + "git.diff", + "git.log", + "git.commit_details", + "git.compare_commit_with_remote", + "git.commit_diff", + "git.stage", + "git.stage_all", + "git.unstage", + "git.unstage_all", + "git.discard", + "git.discard_all", + "git.add_to_gitignore", + "git.commit", + "git.fetch", + "git.pull", + "git.set_remote", + "git.push", + "cron.manage", + "provider.models", + } + + for _, requestType := range expectedTypes { + if websocketRequestHandlers[requestType] == nil { + t.Fatalf("websocketRequestHandlers[%q] is missing", requestType) + } + } + if got := len(websocketRequestHandlers); got != len(expectedTypes) { + t.Fatalf("websocketRequestHandlers has %d entries, want %d", got, len(expectedTypes)) + } +} + +func TestDecodeWebSocketPayloadRejectsUnknownFields(t *testing.T) { + t.Parallel() + + var empty struct{} + if err := decodeWebSocketPayload(nil, &empty); err != nil { + t.Fatalf("decode empty payload: %v", err) + } + + var payload struct { + Known string `json:"known"` + } + if err := decodeWebSocketPayload([]byte(`{"known":"ok","unknown":true}`), &payload); err == nil { + t.Fatal("expected unknown payload field to be rejected") + } +} diff --git a/crates/agent-gateway/internal/server/websocket_settings_handlers.go b/crates/agent-gateway/internal/server/websocket_settings_handlers.go new file mode 100644 index 000000000..8276f623a --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_settings_handlers.go @@ -0,0 +1,81 @@ +package server + +import ( + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleSettingsGet(req websocketRequest) { + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_SettingsGet{ + SettingsGet: &gatewayv1.SettingsGetRequest{}, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + settingsResp := response.GetSettingsGetResp() + if settingsResp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + payload, err := websocketSettingsJSONPayload(settingsResp.GetSettingsJson()) + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + c.sm.ApplySettingsJSON(settingsResp.GetSettingsJson()) + + _ = c.writeResponse(req.ID, payload) +} + +func (c *websocketConnection) handleSettingsUpdate(req websocketRequest) { + payloadJSON, err := websocketRawPayloadJSON(req.Payload) + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_SettingsUpdate{ + SettingsUpdate: &gatewayv1.SettingsUpdateRequest{ + SettingsJson: payloadJSON, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + settingsResp := response.GetSettingsUpdateResp() + if settingsResp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + if settingsResp.GetAccepted() { + c.sm.ApplySettingsJSONPreservingRemote(payloadJSON) + } + + _ = c.writeResponse(req.ID, map[string]any{ + "accepted": settingsResp.GetAccepted(), + "message": strings.TrimSpace(settingsResp.GetMessage()), + }) +} diff --git a/crates/agent-gateway/internal/server/websocket_skills_handlers.go b/crates/agent-gateway/internal/server/websocket_skills_handlers.go new file mode 100644 index 000000000..a8a51e790 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_skills_handlers.go @@ -0,0 +1,260 @@ +package server + +import ( + "encoding/json" + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleSkillFilesList(req websocketRequest) { + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_SkillFilesList{ + SkillFilesList: &gatewayv1.SkillFilesListRequest{}, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetSkillFilesListResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{ + "rootDir": resp.GetRootDir(), + "paths": resp.GetPaths(), + "truncated": resp.GetTruncated(), + }) +} + +func (c *websocketConnection) handleFileMentionList(req websocketRequest) { + type payload struct { + Workdir string `json:"workdir"` + MaxResults *int `json:"max_results"` + Query string `json:"query"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid mentions.list payload") + return + } + + workdir := strings.TrimSpace(body.Workdir) + if workdir == "" { + _ = c.writeError(req.ID, "workdir is required") + return + } + query := strings.TrimSpace(body.Query) + + maxResults, err := websocketOptionalUint32(body.MaxResults, "max_results") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_FileMentionList{ + FileMentionList: &gatewayv1.FileMentionListRequest{ + Workdir: workdir, + MaxResults: maxResults, + Query: query, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetFileMentionListResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + entries := make([]map[string]any, 0, len(resp.GetEntries())) + for _, entry := range resp.GetEntries() { + entries = append(entries, map[string]any{ + "path": entry.GetPath(), + "kind": entry.GetKind(), + }) + } + + _ = c.writeResponse(req.ID, map[string]any{ + "entries": entries, + "truncated": resp.GetTruncated(), + }) +} + +func (c *websocketConnection) handleSkillMetadataRead(req websocketRequest) { + type payload struct { + Path string `json:"path"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid skills.read-metadata payload") + return + } + + path := strings.TrimSpace(body.Path) + if path == "" { + _ = c.writeError(req.ID, "path is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_SkillMetadataRead{ + SkillMetadataRead: &gatewayv1.SkillMetadataReadRequest{ + Path: path, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetSkillMetadataReadResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{ + "name": nullableTrimmedString(resp.GetName()), + "description": nullableTrimmedString(resp.GetDescription()), + }) +} + +func (c *websocketConnection) handleSkillTextRead(req websocketRequest) { + type payload struct { + Path string `json:"path"` + Offset *int `json:"offset"` + Length *int `json:"length"` + } + + var body payload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid skills.read-text payload") + return + } + + path := strings.TrimSpace(body.Path) + if path == "" { + _ = c.writeError(req.ID, "path is required") + return + } + + offset, err := websocketOptionalUint32(body.Offset, "offset") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + length, err := websocketOptionalUint32(body.Length, "length") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_SkillTextRead{ + SkillTextRead: &gatewayv1.SkillTextReadRequest{ + Path: path, + Offset: offset, + Length: length, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetSkillTextReadResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{ + "content": resp.GetContent(), + "truncated": resp.GetTruncated(), + }) +} + +func (c *websocketConnection) handleSkillManage(req websocketRequest) { + payloadJSON := strings.TrimSpace(string(req.Payload)) + if payloadJSON == "" || payloadJSON == "null" { + payloadJSON = "{}" + } + if !json.Valid([]byte(payloadJSON)) { + _ = c.writeError(req.ID, "invalid skills.manage payload") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_SkillManage{ + SkillManage: &gatewayv1.SkillManageRequest{ + PayloadJson: payloadJSON, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetSkillManageResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + var payload any + raw := strings.TrimSpace(resp.GetResultJson()) + if raw == "" { + payload = map[string]any{} + } else if err := json.Unmarshal([]byte(raw), &payload); err != nil { + _ = c.writeError(req.ID, "skill manage response is not valid JSON") + return + } + + _ = c.writeResponse(req.ID, payload) +} diff --git a/crates/agent-gateway/internal/server/websocket_terminal_handlers.go b/crates/agent-gateway/internal/server/websocket_terminal_handlers.go new file mode 100644 index 000000000..fd32b7e1a --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_terminal_handlers.go @@ -0,0 +1,117 @@ +package server + +import ( + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func terminalActionFromRequestType(requestType string) string { + return strings.TrimPrefix(strings.TrimSpace(requestType), "terminal.") +} + +func (c *websocketConnection) handleTerminalRequest(req websocketRequest) { + action := terminalActionFromRequestType(req.Type) + if !c.sm.WebTerminalEnabled() { + _ = c.writeError(req.ID, "web terminal is disabled in desktop Remote settings") + return + } + + var body websocketTerminalRequestPayload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid "+req.Type+" payload") + return + } + + cols, err := websocketOptionalUint32(body.Cols, "cols") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + rows, err := websocketOptionalUint32(body.Rows, "rows") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + maxBytes, err := websocketOptionalUint32(body.MaxBytes, "max_bytes") + if err != nil { + _ = c.writeError(req.ID, err.Error()) + return + } + projectPathKey := strings.TrimSpace(body.ProjectPathKey) + if action == "attach" || action == "snapshot" { + c.rememberTerminalSession(body.SessionID, projectPathKey) + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_TerminalRequest{ + TerminalRequest: &gatewayv1.TerminalRequest{ + Action: action, + SessionId: strings.TrimSpace(body.SessionID), + ProjectPathKey: projectPathKey, + Cwd: strings.TrimSpace(body.Cwd), + Shell: strings.TrimSpace(body.Shell), + Title: strings.TrimSpace(body.Title), + Data: body.Data, + Cols: cols, + Rows: rows, + MaxBytes: maxBytes, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetTerminalResponse() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + c.sm.ApplyTerminalResponseSnapshot(action, projectPathKey, resp) + c.rememberTerminalInterest(action, body, resp) + + _ = c.writeResponse(req.ID, websocketTerminalResponsePayload(resp)) +} + +func (c *websocketConnection) rememberTerminalInterest(action string, body websocketTerminalRequestPayload, resp *gatewayv1.TerminalResponse) { + projectPathKey := strings.TrimSpace(body.ProjectPathKey) + sessionID := strings.TrimSpace(body.SessionID) + if respSession := resp.GetSession(); respSession != nil { + if projectPathKey == "" { + projectPathKey = strings.TrimSpace(respSession.GetProjectPathKey()) + } + if sessionID == "" { + sessionID = strings.TrimSpace(respSession.GetId()) + } + } + + switch action { + case "list", "create", "close_project": + c.rememberTerminalProject(projectPathKey) + case "attach", "snapshot": + c.rememberTerminalSession(sessionID, projectPathKey) + } +} + +func (c *websocketConnection) handleTerminalDetach(req websocketRequest) { + if !c.sm.WebTerminalEnabled() { + _ = c.writeError(req.ID, "web terminal is disabled in desktop Remote settings") + return + } + var body websocketTerminalRequestPayload + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid terminal.detach payload") + return + } + c.forgetTerminalInterest(body.SessionID, body.ProjectPathKey) + _ = c.writeResponse(req.ID, map[string]any{"action": "detach"}) +} diff --git a/crates/agent-gateway/internal/server/websocket_upload_handlers.go b/crates/agent-gateway/internal/server/websocket_upload_handlers.go new file mode 100644 index 000000000..ce55aecd0 --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_upload_handlers.go @@ -0,0 +1,57 @@ +package server + +import ( + "strings" + "time" + + "github.com/liveagent/agent-gateway/internal/handler" + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (c *websocketConnection) handleUploadedImagePreview(req websocketRequest) { + var body handler.UploadedImagePreviewRequestBody + if err := decodeWebSocketPayload(req.Payload, &body); err != nil { + _ = c.writeError(req.ID, "invalid files.preview payload") + return + } + body.Workdir = strings.TrimSpace(body.Workdir) + body.AbsolutePath = strings.TrimSpace(body.AbsolutePath) + if body.Workdir == "" { + _ = c.writeError(req.ID, "workdir is required") + return + } + if body.AbsolutePath == "" { + _ = c.writeError(req.ID, "absolute_path is required") + return + } + + response, err := c.awaitAgentResponse(req.ID, &gatewayv1.GatewayEnvelope{ + RequestId: req.ID, + Timestamp: time.Now().Unix(), + Payload: &gatewayv1.GatewayEnvelope_UploadedImagePreview{ + UploadedImagePreview: &gatewayv1.UploadedImagePreviewRequest{ + Workdir: body.Workdir, + AbsolutePath: body.AbsolutePath, + }, + }, + }) + if err != nil { + _ = c.writeError(req.ID, websocketErrorMessage(err)) + return + } + if errResp := response.GetError(); errResp != nil { + _ = c.writeError(req.ID, errResp.GetMessage()) + return + } + + resp := response.GetUploadedImagePreviewResp() + if resp == nil { + _ = c.writeError(req.ID, "unexpected agent response") + return + } + + _ = c.writeResponse(req.ID, map[string]any{ + "mimeType": resp.GetMimeType(), + "data": resp.GetData(), + }) +} diff --git a/crates/agent-gateway/internal/server/websocket_writer.go b/crates/agent-gateway/internal/server/websocket_writer.go new file mode 100644 index 000000000..61474395d --- /dev/null +++ b/crates/agent-gateway/internal/server/websocket_writer.go @@ -0,0 +1,35 @@ +package server + +import ( + "sync" + "time" + + "golang.org/x/net/websocket" +) + +type websocketConnectionWriter struct { + conn *websocket.Conn + timeout time.Duration + mu sync.Mutex +} + +func newWebsocketConnectionWriter(conn *websocket.Conn, timeout time.Duration) *websocketConnectionWriter { + return &websocketConnectionWriter{ + conn: conn, + timeout: timeout, + } +} + +func (w *websocketConnectionWriter) write(envelope websocketEnvelope) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.timeout > 0 { + if err := w.conn.SetWriteDeadline(time.Now().Add(w.timeout)); err != nil { + return err + } + defer func() { + _ = w.conn.SetWriteDeadline(time.Time{}) + }() + } + return websocket.JSON.Send(w.conn, envelope) +} diff --git a/crates/agent-gateway/internal/session/agent_session.go b/crates/agent-gateway/internal/session/agent_session.go new file mode 100644 index 000000000..790e6ae90 --- /dev/null +++ b/crates/agent-gateway/internal/session/agent_session.go @@ -0,0 +1,134 @@ +package session + +import ( + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func NewAgentSession(auth AuthSnapshot) *AgentSession { + return &AgentSession{ + AgentID: auth.AgentID, + AgentVersion: auth.AgentVersion, + SessionID: auth.SessionID, + ConnectedAt: time.Now(), + LastPing: time.Now(), + toAgent: make(chan *gatewayv1.GatewayEnvelope, 64), + done: make(chan struct{}), + streams: make(map[string]*agentStream), + } +} + +func (s *AgentSession) Outbound() <-chan *gatewayv1.GatewayEnvelope { + return s.toAgent +} + +func (s *AgentSession) Done() <-chan struct{} { + return s.done +} + +func (s *AgentSession) Close() { + s.closeOnce.Do(func() { + s.streamsMu.Lock() + s.closed = true + close(s.done) + for requestID, stream := range s.streams { + delete(s.streams, requestID) + stream.close() + } + s.streamsMu.Unlock() + }) +} + +func (s *AgentSession) SendToAgent(env *gatewayv1.GatewayEnvelope) error { + s.streamsMu.Lock() + closed := s.closed + s.streamsMu.Unlock() + if closed { + return ErrAgentOffline + } + + select { + case <-s.done: + return ErrAgentOffline + case s.toAgent <- env: + return nil + } +} + +func (s *AgentSession) TrySendToAgent(env *gatewayv1.GatewayEnvelope) (bool, error) { + s.streamsMu.Lock() + closed := s.closed + s.streamsMu.Unlock() + if closed { + return false, ErrAgentOffline + } + + select { + case <-s.done: + return false, ErrAgentOffline + default: + } + + select { + case <-s.done: + return false, ErrAgentOffline + case s.toAgent <- env: + return true, nil + default: + return false, nil + } +} + +func (s *AgentSession) registerStream(requestID string) (*agentStream, error) { + stream := &agentStream{ + ch: make(chan *gatewayv1.AgentEnvelope, 64), + done: make(chan struct{}), + } + + s.streamsMu.Lock() + defer s.streamsMu.Unlock() + if s.closed { + stream.close() + return nil, ErrAgentOffline + } + if existing, ok := s.streams[requestID]; ok { + existing.close() + } + s.streams[requestID] = stream + return stream, nil +} + +func (s *AgentSession) unregisterStream(requestID string, stream *agentStream) { + s.streamsMu.Lock() + if existing, ok := s.streams[requestID]; ok && existing == stream { + delete(s.streams, requestID) + existing.close() + } + s.streamsMu.Unlock() +} + +func (s *AgentSession) dispatch(env *gatewayv1.AgentEnvelope) { + s.streamsMu.Lock() + stream := s.streams[env.GetRequestId()] + s.streamsMu.Unlock() + if stream == nil { + return + } + stream.send(env) +} + +func (s *agentStream) close() { + s.closeOnce.Do(func() { + close(s.done) + }) +} + +func (s *agentStream) send(env *gatewayv1.AgentEnvelope) bool { + select { + case <-s.done: + return false + case s.ch <- env: + return true + } +} diff --git a/crates/agent-gateway/internal/session/manager.go b/crates/agent-gateway/internal/session/manager.go index 5db1f882c..c46bf6b1a 100644 --- a/crates/agent-gateway/internal/session/manager.go +++ b/crates/agent-gateway/internal/session/manager.go @@ -1,10 +1,7 @@ package session import ( - "encoding/json" "errors" - "sort" - "strings" "sync" "time" @@ -29,35 +26,9 @@ type AuthSnapshot struct { } type Manager struct { - mu sync.RWMutex - session *AgentSession - sessionEpoch uint64 - lastAuth AuthSnapshot - authValid bool - - historyMu sync.Mutex - nextHistorySubID int - historySubscribers map[int]chan *gatewayv1.HistorySyncEvent - - settingsMu sync.Mutex - nextSettingsSubID int - settingsSubscribers map[int]chan *gatewayv1.SettingsSyncEvent - settingsSnapshotMu sync.RWMutex - settingsSnapshot map[string]any - - terminalMu sync.Mutex - nextTerminalSubID int - terminalSubscribers map[int]chan *gatewayv1.TerminalEvent - terminalSessions map[string]*gatewayv1.TerminalSession - - chatMu sync.Mutex - nextChatSubID int - chatSubscribers map[int]chan *ChatBroadcastEvent - nextChatRunSubID int - chatRuns map[string]*chatRun - chatRunByConversation map[string]string - chatRunByClientRequest map[string]string - historyActiveRuns map[string]activeHistoryRun + registry *sessionRegistry + syncHub *syncHub + chatStore *chatRunStore } type AgentSession struct { @@ -143,1287 +114,8 @@ type Status struct { func NewManager() *Manager { return &Manager{ - historySubscribers: make(map[int]chan *gatewayv1.HistorySyncEvent), - settingsSubscribers: make(map[int]chan *gatewayv1.SettingsSyncEvent), - terminalSubscribers: make(map[int]chan *gatewayv1.TerminalEvent), - terminalSessions: make(map[string]*gatewayv1.TerminalSession), - chatSubscribers: make(map[int]chan *ChatBroadcastEvent), - chatRuns: make(map[string]*chatRun), - chatRunByConversation: make(map[string]string), - chatRunByClientRequest: make(map[string]string), - historyActiveRuns: make(map[string]activeHistoryRun), - } -} - -func NewAgentSession(auth AuthSnapshot) *AgentSession { - return &AgentSession{ - AgentID: auth.AgentID, - AgentVersion: auth.AgentVersion, - SessionID: auth.SessionID, - ConnectedAt: time.Now(), - LastPing: time.Now(), - toAgent: make(chan *gatewayv1.GatewayEnvelope, 64), - done: make(chan struct{}), - streams: make(map[string]*agentStream), - } -} - -func (s *AgentSession) Outbound() <-chan *gatewayv1.GatewayEnvelope { - return s.toAgent -} - -func (s *AgentSession) Done() <-chan struct{} { - return s.done -} - -func (s *AgentSession) Close() { - s.closeOnce.Do(func() { - s.streamsMu.Lock() - s.closed = true - close(s.done) - for requestID, stream := range s.streams { - delete(s.streams, requestID) - stream.close() - } - s.streamsMu.Unlock() - }) -} - -func (s *AgentSession) SendToAgent(env *gatewayv1.GatewayEnvelope) error { - s.streamsMu.Lock() - closed := s.closed - s.streamsMu.Unlock() - if closed { - return ErrAgentOffline - } - - select { - case <-s.done: - return ErrAgentOffline - case s.toAgent <- env: - return nil - } -} - -func (s *AgentSession) TrySendToAgent(env *gatewayv1.GatewayEnvelope) (bool, error) { - s.streamsMu.Lock() - closed := s.closed - s.streamsMu.Unlock() - if closed { - return false, ErrAgentOffline - } - - select { - case <-s.done: - return false, ErrAgentOffline - default: - } - - select { - case <-s.done: - return false, ErrAgentOffline - case s.toAgent <- env: - return true, nil - default: - return false, nil - } -} - -func (s *AgentSession) registerStream(requestID string) (*agentStream, error) { - stream := &agentStream{ - ch: make(chan *gatewayv1.AgentEnvelope, 64), - done: make(chan struct{}), - } - - s.streamsMu.Lock() - defer s.streamsMu.Unlock() - if s.closed { - stream.close() - return nil, ErrAgentOffline - } - if existing, ok := s.streams[requestID]; ok { - existing.close() - } - s.streams[requestID] = stream - return stream, nil -} - -func (s *AgentSession) unregisterStream(requestID string, stream *agentStream) { - s.streamsMu.Lock() - if existing, ok := s.streams[requestID]; ok && existing == stream { - delete(s.streams, requestID) - existing.close() - } - s.streamsMu.Unlock() -} - -func (s *AgentSession) dispatch(env *gatewayv1.AgentEnvelope) { - s.streamsMu.Lock() - stream := s.streams[env.GetRequestId()] - s.streamsMu.Unlock() - if stream == nil { - return - } - stream.send(env) -} - -func (s *agentStream) close() { - s.closeOnce.Do(func() { - close(s.done) - }) -} - -func (s *agentStream) send(env *gatewayv1.AgentEnvelope) bool { - select { - case <-s.done: - return false - case s.ch <- env: - return true - } -} - -func (m *Manager) RecordAuthentication(agentID, agentVersion, sessionID string) { - m.mu.Lock() - defer m.mu.Unlock() - m.lastAuth = AuthSnapshot{ - AgentID: agentID, - AgentVersion: agentVersion, - SessionID: sessionID, - } - m.authValid = true -} - -func (m *Manager) LatestAuthSnapshot() AuthSnapshot { - m.mu.RLock() - defer m.mu.RUnlock() - return m.lastAuth -} - -func (m *Manager) IsOnline() bool { - m.mu.RLock() - defer m.mu.RUnlock() - return m.session != nil -} - -func (m *Manager) SetSession(s *AgentSession) { - m.mu.Lock() - previous := m.session - previousEpoch := m.sessionEpoch - if m.authValid { - s.AgentID = m.lastAuth.AgentID - s.AgentVersion = m.lastAuth.AgentVersion - s.SessionID = m.lastAuth.SessionID - } - if previous != s { - m.sessionEpoch += 1 - } - sessionChanged := previous != s - m.session = s - m.mu.Unlock() - - if sessionChanged { - m.clearTerminalSessionSnapshot() - } - if previous != nil && previous != s { - previous.Close() - m.failOpenChatRunsForSessionEpoch(previousEpoch, agentDisconnectedChatRunMessage) - } -} - -func (m *Manager) ClearSession(session *AgentSession) { - m.mu.Lock() - if m.session != session { - m.mu.Unlock() - return - } - clearedEpoch := m.sessionEpoch - m.session = nil - m.mu.Unlock() - - if session == nil { - return - } - - session.Close() - m.clearTerminalSessionSnapshot() - m.failOpenChatRunsForSessionEpoch(clearedEpoch, agentDisconnectedChatRunMessage) -} - -func (m *Manager) Status() Status { - m.mu.RLock() - defer m.mu.RUnlock() - - status := Status{} - if m.authValid { - status.AgentID = m.lastAuth.AgentID - status.AgentVersion = m.lastAuth.AgentVersion - status.SessionID = m.lastAuth.SessionID - } - if m.session == nil { - return status - } - status.Online = true - status.AgentID = m.session.AgentID - status.AgentVersion = m.session.AgentVersion - status.SessionID = m.session.SessionID - status.ConnectedSince = m.session.ConnectedAt.Unix() - status.LastHeartbeat = m.session.LastPing.Unix() - return status -} - -func (m *Manager) TouchHeartbeat(session *AgentSession) { - m.mu.Lock() - defer m.mu.Unlock() - if m.session == session { - m.session.LastPing = time.Now() - } -} - -func (m *Manager) SendToAgent(env *gatewayv1.GatewayEnvelope) error { - m.mu.RLock() - session := m.session - m.mu.RUnlock() - if session == nil { - return ErrAgentOffline - } - - return session.SendToAgent(env) -} - -func (m *Manager) currentSessionEpoch() uint64 { - m.mu.RLock() - defer m.mu.RUnlock() - return m.sessionEpoch -} - -func (m *Manager) RegisterStream(requestID string) (<-chan *gatewayv1.AgentEnvelope, <-chan struct{}, func(), error) { - m.mu.RLock() - session := m.session - m.mu.RUnlock() - if session == nil { - return nil, nil, nil, ErrAgentOffline - } - - stream, err := session.registerStream(requestID) - if err != nil { - return nil, nil, nil, err - } - - cleanup := func() { - session.unregisterStream(requestID, stream) - } - - return stream.ch, stream.done, cleanup, nil -} - -func (m *Manager) SubscribeHistorySync() (<-chan *gatewayv1.HistorySyncEvent, func()) { - ch := make(chan *gatewayv1.HistorySyncEvent, 32) - - m.historyMu.Lock() - subID := m.nextHistorySubID - m.nextHistorySubID += 1 - m.historySubscribers[subID] = ch - m.historyMu.Unlock() - - cleanup := func() { - m.historyMu.Lock() - existing, ok := m.historySubscribers[subID] - if ok { - delete(m.historySubscribers, subID) - close(existing) - } - m.historyMu.Unlock() - } - - return ch, cleanup -} - -func (m *Manager) broadcastHistorySync(event *gatewayv1.HistorySyncEvent) { - if event == nil { - return - } - - m.updateActiveHistoryRun(event) - m.releaseCompletedChatRunAfterHistoryUpsert(event) - - m.historyMu.Lock() - subscribers := make([]chan *gatewayv1.HistorySyncEvent, 0, len(m.historySubscribers)) - for _, ch := range m.historySubscribers { - subscribers = append(subscribers, ch) - } - m.historyMu.Unlock() - - for _, ch := range subscribers { - select { - case ch <- event: - default: - } - } -} - -func historySyncConversationID(event *gatewayv1.HistorySyncEvent) string { - conversationID := strings.TrimSpace(event.GetConversationId()) - if conversationID == "" && event.GetConversation() != nil { - conversationID = strings.TrimSpace(event.GetConversation().GetId()) - } - return conversationID -} - -func historySyncWorkdir(event *gatewayv1.HistorySyncEvent) string { - if event == nil || event.GetConversation() == nil { - return "" - } - return strings.TrimSpace(event.GetConversation().GetCwd()) -} - -func (m *Manager) updateActiveHistoryRun(event *gatewayv1.HistorySyncEvent) { - kind := strings.TrimSpace(event.GetKind()) - conversationID := historySyncConversationID(event) - if conversationID == "" { - return - } - - workdir := historySyncWorkdir(event) - now := time.Now() - - m.chatMu.Lock() - defer m.chatMu.Unlock() - m.pruneExpiredChatRunsLocked(now) - - switch kind { - case "running": - existing := m.historyActiveRuns[conversationID] - if workdir == "" { - workdir = existing.workdir - } - m.historyActiveRuns[conversationID] = activeHistoryRun{ - conversationID: conversationID, - workdir: workdir, - updatedAt: now, - } - if requestID := m.chatRunByConversation[conversationID]; requestID != "" { - if run := m.chatRuns[requestID]; run != nil && workdir != "" { - run.workdir = workdir - } - } - case "idle", "delete": - delete(m.historyActiveRuns, conversationID) - case "upsert": - if workdir == "" { - return - } - if existing, ok := m.historyActiveRuns[conversationID]; ok { - existing.workdir = workdir - existing.updatedAt = now - m.historyActiveRuns[conversationID] = existing - } - if requestID := m.chatRunByConversation[conversationID]; requestID != "" { - if run := m.chatRuns[requestID]; run != nil { - run.workdir = workdir - } - } - } -} - -func (m *Manager) releaseCompletedChatRunAfterHistoryUpsert(event *gatewayv1.HistorySyncEvent) { - if strings.TrimSpace(event.GetKind()) != "upsert" { - return - } - - conversationID := historySyncConversationID(event) - if conversationID == "" { - return - } - - m.chatMu.Lock() - defer m.chatMu.Unlock() - requestID := m.chatRunByConversation[conversationID] - run := m.chatRuns[requestID] - if run == nil || !run.done { - return - } - m.releaseCompletedChatRunLocked(requestID, run) -} - -func (m *Manager) SubscribeSettingsSync() (<-chan *gatewayv1.SettingsSyncEvent, func()) { - ch := make(chan *gatewayv1.SettingsSyncEvent, 32) - - m.settingsMu.Lock() - subID := m.nextSettingsSubID - m.nextSettingsSubID += 1 - m.settingsSubscribers[subID] = ch - m.settingsMu.Unlock() - - cleanup := func() { - m.settingsMu.Lock() - existing, ok := m.settingsSubscribers[subID] - if ok { - delete(m.settingsSubscribers, subID) - close(existing) - } - m.settingsMu.Unlock() - } - - return ch, cleanup -} - -func (m *Manager) WebTerminalEnabled() bool { - m.settingsSnapshotMu.RLock() - defer m.settingsSnapshotMu.RUnlock() - - remote, ok := m.settingsSnapshot["remote"].(map[string]any) - if !ok { - return false - } - enabled, ok := remote["enableWebTerminal"].(bool) - return ok && enabled -} - -func (m *Manager) WebGitEnabled() bool { - m.settingsSnapshotMu.RLock() - defer m.settingsSnapshotMu.RUnlock() - - remote, ok := m.settingsSnapshot["remote"].(map[string]any) - if !ok { - return false - } - enabled, ok := remote["enableWebGit"].(bool) - return ok && enabled -} - -func (m *Manager) updateSettingsSnapshot(event *gatewayv1.SettingsSyncEvent) { - if event == nil { - return - } - m.ApplySettingsJSON(event.GetSettingsJson()) -} - -func parseSettingsJSON(settingsJSON string) (map[string]any, bool) { - raw := strings.TrimSpace(settingsJSON) - if raw == "" { - return nil, false - } - var payload map[string]any - if err := json.Unmarshal([]byte(raw), &payload); err != nil || payload == nil { - return nil, false - } - return payload, true -} - -func (m *Manager) ApplySettingsJSON(settingsJSON string) { - payload, ok := parseSettingsJSON(settingsJSON) - if !ok { - return - } - m.settingsSnapshotMu.Lock() - if _, hasIncomingRemote := payload["remote"]; !hasIncomingRemote { - if existingRemote, hasExistingRemote := m.settingsSnapshot["remote"]; hasExistingRemote { - payload["remote"] = existingRemote - } - } - m.settingsSnapshot = payload - m.settingsSnapshotMu.Unlock() -} - -func (m *Manager) ApplySettingsJSONPreservingRemote(settingsJSON string) { - payload, ok := parseSettingsJSON(settingsJSON) - if !ok { - return - } - m.settingsSnapshotMu.Lock() - if existingRemote, ok := m.settingsSnapshot["remote"]; ok { - payload["remote"] = existingRemote - } else { - delete(payload, "remote") - } - m.settingsSnapshot = payload - m.settingsSnapshotMu.Unlock() -} - -func (m *Manager) broadcastSettingsSync(event *gatewayv1.SettingsSyncEvent) { - if event == nil { - return - } - m.updateSettingsSnapshot(event) - - m.settingsMu.Lock() - subscribers := make([]chan *gatewayv1.SettingsSyncEvent, 0, len(m.settingsSubscribers)) - for _, ch := range m.settingsSubscribers { - subscribers = append(subscribers, ch) - } - m.settingsMu.Unlock() - - for _, ch := range subscribers { - select { - case ch <- event: - default: - } - } -} - -func (m *Manager) SubscribeTerminalEvents() (<-chan *gatewayv1.TerminalEvent, func()) { - ch := make(chan *gatewayv1.TerminalEvent, 4096) - - m.terminalMu.Lock() - subID := m.nextTerminalSubID - m.nextTerminalSubID += 1 - m.terminalSubscribers[subID] = ch - m.terminalMu.Unlock() - - cleanup := func() { - m.terminalMu.Lock() - existing, ok := m.terminalSubscribers[subID] - if ok { - delete(m.terminalSubscribers, subID) - close(existing) - } - m.terminalMu.Unlock() - } - - return ch, cleanup -} - -func cloneTerminalSession(session *gatewayv1.TerminalSession) *gatewayv1.TerminalSession { - if session == nil { - return nil - } - return &gatewayv1.TerminalSession{ - Id: session.GetId(), - ProjectPathKey: session.GetProjectPathKey(), - Cwd: session.GetCwd(), - Shell: session.GetShell(), - Title: session.GetTitle(), - Pid: session.GetPid(), - Cols: session.GetCols(), - Rows: session.GetRows(), - CreatedAt: session.GetCreatedAt(), - UpdatedAt: session.GetUpdatedAt(), - FinishedAt: session.GetFinishedAt(), - ExitCode: session.GetExitCode(), - Running: session.GetRunning(), - } -} - -func terminalSessionSortKey(session *gatewayv1.TerminalSession) (string, uint64, string) { - if session == nil { - return "", 0, "" - } - return strings.TrimSpace(session.GetProjectPathKey()), session.GetCreatedAt(), strings.TrimSpace(session.GetId()) -} - -func sortTerminalSessions(sessions []*gatewayv1.TerminalSession) { - sort.Slice(sessions, func(i, j int) bool { - leftProject, leftCreatedAt, leftID := terminalSessionSortKey(sessions[i]) - rightProject, rightCreatedAt, rightID := terminalSessionSortKey(sessions[j]) - if leftProject != rightProject { - return leftProject < rightProject - } - if leftCreatedAt != rightCreatedAt { - return leftCreatedAt < rightCreatedAt - } - return leftID < rightID - }) -} - -func terminalSessionMatchesProject(session *gatewayv1.TerminalSession, projectPathKey string) bool { - projectPathKey = strings.TrimSpace(projectPathKey) - if projectPathKey == "" { - return true - } - if session == nil { - return false - } - return strings.TrimSpace(session.GetProjectPathKey()) == projectPathKey -} - -func (m *Manager) clearTerminalSessionSnapshot() { - m.terminalMu.Lock() - m.terminalSessions = make(map[string]*gatewayv1.TerminalSession) - m.terminalMu.Unlock() -} - -func (m *Manager) TerminalSessionSnapshot(projectPathKey string) []*gatewayv1.TerminalSession { - projectPathKey = strings.TrimSpace(projectPathKey) - m.terminalMu.Lock() - sessions := make([]*gatewayv1.TerminalSession, 0, len(m.terminalSessions)) - for _, session := range m.terminalSessions { - if !terminalSessionMatchesProject(session, projectPathKey) { - continue - } - if cloned := cloneTerminalSession(session); cloned != nil { - sessions = append(sessions, cloned) - } - } - m.terminalMu.Unlock() - sortTerminalSessions(sessions) - return sessions -} - -func (m *Manager) ReplaceTerminalSessionSnapshot( - projectPathKey string, - sessions []*gatewayv1.TerminalSession, -) { - projectPathKey = strings.TrimSpace(projectPathKey) - m.terminalMu.Lock() - if projectPathKey == "" { - m.terminalSessions = make(map[string]*gatewayv1.TerminalSession) - } else { - for id, session := range m.terminalSessions { - if terminalSessionMatchesProject(session, projectPathKey) { - delete(m.terminalSessions, id) - } - } - } - for _, session := range sessions { - id := strings.TrimSpace(session.GetId()) - if id == "" { - continue - } - m.terminalSessions[id] = cloneTerminalSession(session) - } - m.terminalMu.Unlock() -} - -func (m *Manager) ApplyTerminalResponseSnapshot( - action string, - projectPathKey string, - resp *gatewayv1.TerminalResponse, -) { - if resp == nil { - return - } - action = strings.TrimSpace(action) - projectPathKey = strings.TrimSpace(projectPathKey) - - switch action { - case "list": - m.ReplaceTerminalSessionSnapshot(projectPathKey, resp.GetSessions()) - case "close_project": - m.ReplaceTerminalSessionSnapshot(projectPathKey, nil) - case "close": - if sessionID := strings.TrimSpace(resp.GetSession().GetId()); sessionID != "" { - m.terminalMu.Lock() - delete(m.terminalSessions, sessionID) - m.terminalMu.Unlock() - } - case "create", "attach", "snapshot", "input", "resize", "rename": - session := resp.GetSession() - sessionID := strings.TrimSpace(session.GetId()) - if sessionID == "" { - return - } - m.terminalMu.Lock() - m.terminalSessions[sessionID] = cloneTerminalSession(session) - m.terminalMu.Unlock() - } -} - -func (m *Manager) applyTerminalEventSnapshot(event *gatewayv1.TerminalEvent) { - if event == nil { - return - } - kind := strings.TrimSpace(event.GetKind()) - sessionID := strings.TrimSpace(event.GetSessionId()) - if sessionID == "" && event.GetSession() != nil { - sessionID = strings.TrimSpace(event.GetSession().GetId()) - } - if sessionID == "" { - return - } - - m.terminalMu.Lock() - if kind == "closed" { - delete(m.terminalSessions, sessionID) - } else if session := cloneTerminalSession(event.GetSession()); session != nil { - m.terminalSessions[sessionID] = session - } - m.terminalMu.Unlock() -} - -func (m *Manager) broadcastTerminalEvent(event *gatewayv1.TerminalEvent) { - if event == nil { - return - } - - m.applyTerminalEventSnapshot(event) - - m.terminalMu.Lock() - subscribers := make([]chan *gatewayv1.TerminalEvent, 0, len(m.terminalSubscribers)) - for _, ch := range m.terminalSubscribers { - subscribers = append(subscribers, ch) - } - m.terminalMu.Unlock() - - for _, ch := range subscribers { - select { - case ch <- event: - case <-time.After(50 * time.Millisecond): - } - } -} - -func (m *Manager) SubscribeChatEvents() (<-chan *ChatBroadcastEvent, func()) { - ch := make(chan *ChatBroadcastEvent, 128) - - m.chatMu.Lock() - subID := m.nextChatSubID - m.nextChatSubID += 1 - m.chatSubscribers[subID] = ch - m.chatMu.Unlock() - - cleanup := func() { - m.chatMu.Lock() - existing, ok := m.chatSubscribers[subID] - if ok { - delete(m.chatSubscribers, subID) - close(existing) - } - m.chatMu.Unlock() - } - - return ch, cleanup -} - -func (m *Manager) StartChatRun(requestID string, conversationID string) (ChatRunSnapshot, error) { - snapshot, _, err := m.StartChatRunWithClientRequest(requestID, conversationID, "", "") - return snapshot, err -} - -func (m *Manager) StartChatRunWithClientRequest( - requestID string, - conversationID string, - clientRequestID string, - workdirInput ...string, -) (ChatRunSnapshot, bool, error) { - requestID = strings.TrimSpace(requestID) - if requestID == "" { - return ChatRunSnapshot{}, false, ErrChatRunNotFound - } - - now := time.Now() - conversationID = strings.TrimSpace(conversationID) - clientRequestID = strings.TrimSpace(clientRequestID) - workdir := "" - if len(workdirInput) > 0 { - workdir = strings.TrimSpace(workdirInput[0]) - } - sessionEpoch := m.currentSessionEpoch() - - m.chatMu.Lock() - defer m.chatMu.Unlock() - m.pruneExpiredChatRunsLocked(now) - - if clientRequestID != "" { - if existingRequestID := m.chatRunByClientRequest[clientRequestID]; existingRequestID != "" { - if existing := m.chatRuns[existingRequestID]; existing != nil { - if !existing.done { - if workdir != "" && existing.workdir == "" { - existing.workdir = workdir - } - return existing.snapshot(), false, nil - } - m.releaseCompletedChatRunLocked(existingRequestID, existing) - } - delete(m.chatRunByClientRequest, clientRequestID) - } - } - - if existing := m.chatRuns[requestID]; existing != nil { - m.removeChatRunLocked(requestID, existing) - } - - run := &chatRun{ - requestID: requestID, - conversationID: conversationID, - clientRequestID: clientRequestID, - workdir: workdir, - sessionEpoch: sessionEpoch, - updatedAt: now, - subscribers: make(map[int]*chatRunSubscriber), - } - m.chatRuns[requestID] = run - if conversationID != "" { - m.chatRunByConversation[conversationID] = requestID - } - if clientRequestID != "" { - m.chatRunByClientRequest[clientRequestID] = requestID - } - - return run.snapshot(), true, nil -} - -func (m *Manager) RemoveChatRun(requestID string) { - requestID = strings.TrimSpace(requestID) - if requestID == "" { - return - } - - m.chatMu.Lock() - defer m.chatMu.Unlock() - run := m.chatRuns[requestID] - if run == nil { - return - } - m.removeChatRunLocked(requestID, run) -} - -func (m *Manager) RemoveChatRunByConversation(conversationID string) { - conversationID = strings.TrimSpace(conversationID) - if conversationID == "" { - return - } - - m.chatMu.Lock() - defer m.chatMu.Unlock() - requestID := m.chatRunByConversation[conversationID] - run := m.chatRuns[requestID] - if run == nil { - for candidateRequestID, candidateRun := range m.chatRuns { - if strings.TrimSpace(candidateRun.conversationID) == conversationID { - requestID = candidateRequestID - run = candidateRun - break - } - } - } - if run == nil { - return - } - m.removeChatRunLocked(requestID, run) -} - -func (m *Manager) ActiveChatRunSummaries() []ActiveChatRunSummary { - m.chatMu.Lock() - defer m.chatMu.Unlock() - now := time.Now() - m.pruneExpiredChatRunsLocked(now) - - seen := make(map[string]int, len(m.chatRuns)+len(m.historyActiveRuns)) - summaries := make([]ActiveChatRunSummary, 0, len(m.chatRuns)+len(m.historyActiveRuns)) - for _, run := range m.chatRuns { - if run == nil || run.done { - continue - } - conversationID := strings.TrimSpace(run.conversationID) - if conversationID == "" { - continue - } - summary := ActiveChatRunSummary{ - ConversationID: conversationID, - Workdir: strings.TrimSpace(run.workdir), - UpdatedAt: run.updatedAt.UnixMilli(), - } - if index, ok := seen[conversationID]; ok { - if summaries[index].Workdir == "" { - summaries[index].Workdir = summary.Workdir - } - if summary.UpdatedAt > summaries[index].UpdatedAt { - summaries[index].UpdatedAt = summary.UpdatedAt - } - continue - } - seen[conversationID] = len(summaries) - summaries = append(summaries, summary) - } - - for conversationID, run := range m.historyActiveRuns { - conversationID = strings.TrimSpace(conversationID) - if conversationID == "" { - continue - } - if now.Sub(run.updatedAt) > chatRunStaleRetention { - delete(m.historyActiveRuns, conversationID) - continue - } - workdir := strings.TrimSpace(run.workdir) - updatedAt := run.updatedAt.UnixMilli() - if index, ok := seen[conversationID]; ok { - if summaries[index].Workdir == "" { - summaries[index].Workdir = workdir - } - if updatedAt > summaries[index].UpdatedAt { - summaries[index].UpdatedAt = updatedAt - } - continue - } - seen[conversationID] = len(summaries) - summaries = append(summaries, ActiveChatRunSummary{ - ConversationID: conversationID, - Workdir: workdir, - UpdatedAt: updatedAt, - }) - } - - sort.Slice(summaries, func(i, j int) bool { - return summaries[i].ConversationID < summaries[j].ConversationID - }) - return summaries -} - -func (m *Manager) ActiveChatRunConversationIDs() []string { - summaries := m.ActiveChatRunSummaries() - ids := make([]string, 0, len(summaries)) - for _, summary := range summaries { - if conversationID := strings.TrimSpace(summary.ConversationID); conversationID != "" { - ids = append(ids, conversationID) - } - } - return ids -} - -func (m *Manager) failOpenChatRunsForSessionEpoch(sessionEpoch uint64, message string) { - message = strings.TrimSpace(message) - if message == "" { - message = agentDisconnectedChatRunMessage - } - - data, err := json.Marshal(map[string]string{"message": message}) - if err != nil { - data = []byte(`{"message":"Desktop agent disconnected. Please retry."}`) - } - now := time.Now() - - type broadcastTarget struct { - event *ChatBroadcastEvent - subscribers []*chatRunSubscriber - } - targets := make([]broadcastTarget, 0) - globalSubscribers := make([]chan *ChatBroadcastEvent, 0) - - m.chatMu.Lock() - m.pruneExpiredChatRunsLocked(now) - for requestID, run := range m.chatRuns { - if run == nil || run.done || run.sessionEpoch != sessionEpoch { - continue - } - - run.nextSeq += 1 - run.updatedAt = now - run.done = true - run.expiresAt = now.Add(chatRunDoneRetention) - - chatEvent := &gatewayv1.ChatEvent{ - Type: gatewayv1.ChatEvent_ERROR, - ConversationId: strings.TrimSpace(run.conversationID), - Data: string(data), - } - broadcast := &ChatBroadcastEvent{ - RequestID: requestID, - Event: chatEvent, - Seq: run.nextSeq, - Workdir: strings.TrimSpace(run.workdir), - } - run.events = append(run.events, cloneChatBroadcastEvent(broadcast)) - if len(run.events) > maxBufferedChatRunEvents { - copy(run.events, run.events[len(run.events)-maxBufferedChatRunEvents:]) - run.events = run.events[:maxBufferedChatRunEvents] - } - - subscribers := make([]*chatRunSubscriber, 0, len(run.subscribers)) - for _, subscriber := range run.subscribers { - subscribers = append(subscribers, subscriber) - } - targets = append(targets, broadcastTarget{ - event: broadcast, - subscribers: subscribers, - }) - } - for _, ch := range m.chatSubscribers { - globalSubscribers = append(globalSubscribers, ch) - } - m.chatMu.Unlock() - - for _, target := range targets { - for _, subscriber := range target.subscribers { - select { - case <-subscriber.done: - case subscriber.ch <- cloneChatBroadcastEvent(target.event): - } - } - for _, ch := range globalSubscribers { - select { - case ch <- cloneChatBroadcastEvent(target.event): - default: - } - } - } -} - -func (m *Manager) SubscribeChatRun( - requestID string, - conversationID string, - afterSeq int64, -) (<-chan *ChatBroadcastEvent, <-chan struct{}, func(), ChatRunSnapshot, error) { - requestID = strings.TrimSpace(requestID) - conversationID = strings.TrimSpace(conversationID) - if afterSeq < 0 { - afterSeq = 0 - } - - m.chatMu.Lock() - defer m.chatMu.Unlock() - m.pruneExpiredChatRunsLocked(time.Now()) - - if requestID == "" && conversationID != "" { - requestID = m.chatRunByConversation[conversationID] - } - run := m.chatRuns[requestID] - if run == nil { - done := make(chan struct{}) - close(done) - return nil, done, func() {}, ChatRunSnapshot{}, ErrChatRunNotFound - } - - replay := make([]*ChatBroadcastEvent, 0) - for _, event := range run.events { - if event.Seq > afterSeq { - replay = append(replay, cloneChatBroadcastEvent(event)) - } - } - - bufferSize := len(replay) + 128 - if bufferSize < 128 { - bufferSize = 128 - } - ch := make(chan *ChatBroadcastEvent, bufferSize) - done := make(chan struct{}) - for _, event := range replay { - ch <- event - } - - subID := -1 - var subscriber *chatRunSubscriber - if !run.done { - subID = m.nextChatRunSubID - m.nextChatRunSubID += 1 - subscriber = &chatRunSubscriber{ - ch: ch, - done: done, - } - run.subscribers[subID] = subscriber - } - - var cleanupOnce sync.Once - cleanup := func() { - cleanupOnce.Do(func() { - m.chatMu.Lock() - if subID >= 0 { - if current := m.chatRuns[requestID]; current != nil { - delete(current.subscribers, subID) - } - } - m.chatMu.Unlock() - if subscriber != nil { - subscriber.close() - } else { - close(done) - } - }) - } - - return ch, done, cleanup, run.snapshot(), nil -} - -func (m *Manager) broadcastChatEvent(requestID string, event *gatewayv1.ChatEvent) { - if event == nil { - return - } - - requestID = strings.TrimSpace(requestID) - conversationID := strings.TrimSpace(event.GetConversationId()) - now := time.Now() - sessionEpoch := m.currentSessionEpoch() - - m.chatMu.Lock() - m.pruneExpiredChatRunsLocked(now) - broadcast := &ChatBroadcastEvent{ - RequestID: requestID, - Event: event, - } - var runSubscribers []*chatRunSubscriber - run := m.chatRuns[requestID] - if run == nil && requestID != "" { - run = &chatRun{ - requestID: requestID, - conversationID: conversationID, - sessionEpoch: sessionEpoch, - updatedAt: now, - subscribers: make(map[int]*chatRunSubscriber), - } - m.chatRuns[requestID] = run - if conversationID != "" { - m.chatRunByConversation[conversationID] = requestID - } - } - if run != nil { - run.nextSeq += 1 - run.updatedAt = now - if conversationID != "" { - if run.conversationID != "" && run.conversationID != conversationID { - if m.chatRunByConversation[run.conversationID] == requestID { - delete(m.chatRunByConversation, run.conversationID) - } - } - run.conversationID = conversationID - m.chatRunByConversation[conversationID] = requestID - if run.workdir == "" { - if activeRun, ok := m.historyActiveRuns[conversationID]; ok { - run.workdir = strings.TrimSpace(activeRun.workdir) - } - } - } - broadcast.Seq = run.nextSeq - broadcast.Workdir = strings.TrimSpace(run.workdir) - run.events = append(run.events, cloneChatBroadcastEvent(broadcast)) - if len(run.events) > maxBufferedChatRunEvents { - copy(run.events, run.events[len(run.events)-maxBufferedChatRunEvents:]) - run.events = run.events[:maxBufferedChatRunEvents] - } - if isTerminalChatEvent(event) { - run.done = true - run.expiresAt = now.Add(chatRunDoneRetention) - } - runSubscribers = make([]*chatRunSubscriber, 0, len(run.subscribers)) - for _, subscriber := range run.subscribers { - runSubscribers = append(runSubscribers, subscriber) - } - } - subscribers := make([]chan *ChatBroadcastEvent, 0, len(m.chatSubscribers)) - for _, ch := range m.chatSubscribers { - subscribers = append(subscribers, ch) - } - m.chatMu.Unlock() - - for _, subscriber := range runSubscribers { - select { - case <-subscriber.done: - case subscriber.ch <- cloneChatBroadcastEvent(broadcast): - } - } - for _, ch := range subscribers { - select { - case ch <- broadcast: - default: - } - } -} - -func (m *Manager) DispatchFromAgent(env *gatewayv1.AgentEnvelope) { - m.dispatchFromAgent(nil, env) -} - -func (m *Manager) DispatchFromAgentForSession(session *AgentSession, env *gatewayv1.AgentEnvelope) { - m.dispatchFromAgent(session, env) -} - -func (m *Manager) dispatchFromAgent(expected *AgentSession, env *gatewayv1.AgentEnvelope) { - m.mu.RLock() - session := m.session - m.mu.RUnlock() - if session == nil || (expected != nil && session != expected) { - return - } - - if chatEvent := env.GetChatEvent(); chatEvent != nil { - m.broadcastChatEvent(env.GetRequestId(), chatEvent) - } - - if historySync := env.GetHistorySync(); historySync != nil { - m.broadcastHistorySync(historySync) - return - } - - if settingsSync := env.GetSettingsSync(); settingsSync != nil { - m.broadcastSettingsSync(settingsSync) - return - } - - if terminalEvent := env.GetTerminalEvent(); terminalEvent != nil { - m.broadcastTerminalEvent(terminalEvent) - return - } - - session.dispatch(env) -} - -func (r *chatRun) snapshot() ChatRunSnapshot { - firstSeq := int64(0) - if len(r.events) > 0 { - firstSeq = r.events[0].Seq - } - return ChatRunSnapshot{ - RequestID: r.requestID, - ConversationID: r.conversationID, - ClientRequestID: r.clientRequestID, - Workdir: r.workdir, - FirstSeq: firstSeq, - LatestSeq: r.nextSeq, - Done: r.done, - } -} - -func (s *chatRunSubscriber) close() { - s.closeOnce.Do(func() { - close(s.done) - }) -} - -func (m *Manager) pruneExpiredChatRunsLocked(now time.Time) { - for requestID, run := range m.chatRuns { - if run == nil { - delete(m.chatRuns, requestID) - continue - } - if run.done { - if !run.expiresAt.IsZero() && now.After(run.expiresAt) { - m.removeChatRunLocked(requestID, run) - } - continue - } - if !run.updatedAt.IsZero() && now.Sub(run.updatedAt) > chatRunStaleRetention { - m.removeChatRunLocked(requestID, run) - } - } -} - -func (m *Manager) removeChatRunLocked(requestID string, run *chatRun) { - if run.conversationID != "" && m.chatRunByConversation[run.conversationID] == requestID { - delete(m.chatRunByConversation, run.conversationID) - } - if run.clientRequestID != "" && m.chatRunByClientRequest[run.clientRequestID] == requestID { - delete(m.chatRunByClientRequest, run.clientRequestID) - } - delete(m.chatRuns, requestID) - for _, subscriber := range run.subscribers { - subscriber.close() - } -} - -func (m *Manager) releaseCompletedChatRunLocked(requestID string, run *chatRun) { - if run.conversationID != "" && m.chatRunByConversation[run.conversationID] == requestID { - delete(m.chatRunByConversation, run.conversationID) - } - if run.clientRequestID != "" && m.chatRunByClientRequest[run.clientRequestID] == requestID { - delete(m.chatRunByClientRequest, run.clientRequestID) - } - delete(m.chatRuns, requestID) -} - -func cloneChatBroadcastEvent(event *ChatBroadcastEvent) *ChatBroadcastEvent { - if event == nil { - return nil - } - return &ChatBroadcastEvent{ - RequestID: event.RequestID, - Event: event.Event, - Seq: event.Seq, - Workdir: event.Workdir, - } -} - -func isTerminalChatEvent(event *gatewayv1.ChatEvent) bool { - if event == nil { - return false + registry: newSessionRegistry(), + syncHub: newSyncHub(), + chatStore: newChatRunStore(), } - return event.GetType() == gatewayv1.ChatEvent_DONE || event.GetType() == gatewayv1.ChatEvent_ERROR } diff --git a/crates/agent-gateway/internal/session/manager_chat_runs.go b/crates/agent-gateway/internal/session/manager_chat_runs.go new file mode 100644 index 000000000..ef8eb3e35 --- /dev/null +++ b/crates/agent-gateway/internal/session/manager_chat_runs.go @@ -0,0 +1,577 @@ +package session + +import ( + "encoding/json" + "sort" + "strings" + "sync" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (m *Manager) SubscribeChatEvents() (<-chan *ChatBroadcastEvent, func()) { + ch := make(chan *ChatBroadcastEvent, 128) + + m.chatStore.chatMu.Lock() + subID := m.chatStore.nextChatSubID + m.chatStore.nextChatSubID += 1 + m.chatStore.chatSubscribers[subID] = ch + m.chatStore.chatMu.Unlock() + + cleanup := func() { + m.chatStore.chatMu.Lock() + existing, ok := m.chatStore.chatSubscribers[subID] + if ok { + delete(m.chatStore.chatSubscribers, subID) + close(existing) + } + m.chatStore.chatMu.Unlock() + } + + return ch, cleanup +} + +func (m *Manager) StartChatRun(requestID string, conversationID string) (ChatRunSnapshot, error) { + snapshot, _, err := m.StartChatRunWithClientRequest(requestID, conversationID, "", "") + return snapshot, err +} + +func (m *Manager) StartChatRunWithClientRequest( + requestID string, + conversationID string, + clientRequestID string, + workdirInput ...string, +) (ChatRunSnapshot, bool, error) { + requestID = strings.TrimSpace(requestID) + if requestID == "" { + return ChatRunSnapshot{}, false, ErrChatRunNotFound + } + + now := time.Now() + conversationID = strings.TrimSpace(conversationID) + clientRequestID = strings.TrimSpace(clientRequestID) + workdir := "" + if len(workdirInput) > 0 { + workdir = strings.TrimSpace(workdirInput[0]) + } + sessionEpoch := m.currentSessionEpoch() + + m.chatStore.chatMu.Lock() + defer m.chatStore.chatMu.Unlock() + m.pruneExpiredChatRunsLocked(now) + + if clientRequestID != "" { + if existingRequestID := m.chatStore.chatRunByClientRequest[clientRequestID]; existingRequestID != "" { + if existing := m.chatStore.chatRuns[existingRequestID]; existing != nil { + if !existing.done { + if workdir != "" && existing.workdir == "" { + existing.workdir = workdir + } + return existing.snapshot(), false, nil + } + m.releaseCompletedChatRunLocked(existingRequestID, existing) + } + delete(m.chatStore.chatRunByClientRequest, clientRequestID) + } + } + + if existing := m.chatStore.chatRuns[requestID]; existing != nil { + m.removeChatRunLocked(requestID, existing) + } + + run := &chatRun{ + requestID: requestID, + conversationID: conversationID, + clientRequestID: clientRequestID, + workdir: workdir, + sessionEpoch: sessionEpoch, + updatedAt: now, + subscribers: make(map[int]*chatRunSubscriber), + } + m.chatStore.chatRuns[requestID] = run + if conversationID != "" { + m.chatStore.chatRunByConversation[conversationID] = requestID + } + if clientRequestID != "" { + m.chatStore.chatRunByClientRequest[clientRequestID] = requestID + } + + return run.snapshot(), true, nil +} + +func (m *Manager) RemoveChatRun(requestID string) { + requestID = strings.TrimSpace(requestID) + if requestID == "" { + return + } + + m.chatStore.chatMu.Lock() + defer m.chatStore.chatMu.Unlock() + run := m.chatStore.chatRuns[requestID] + if run == nil { + return + } + m.removeChatRunLocked(requestID, run) +} + +func (m *Manager) RemoveChatRunByConversation(conversationID string) { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + return + } + + m.chatStore.chatMu.Lock() + defer m.chatStore.chatMu.Unlock() + requestID := m.chatStore.chatRunByConversation[conversationID] + run := m.chatStore.chatRuns[requestID] + if run == nil { + for candidateRequestID, candidateRun := range m.chatStore.chatRuns { + if strings.TrimSpace(candidateRun.conversationID) == conversationID { + requestID = candidateRequestID + run = candidateRun + break + } + } + } + if run == nil { + return + } + m.removeChatRunLocked(requestID, run) +} + +func (m *Manager) ActiveChatRunSummaries() []ActiveChatRunSummary { + m.chatStore.chatMu.Lock() + defer m.chatStore.chatMu.Unlock() + now := time.Now() + m.pruneExpiredChatRunsLocked(now) + + seen := make(map[string]int, len(m.chatStore.chatRuns)+len(m.chatStore.historyActiveRuns)) + summaries := make([]ActiveChatRunSummary, 0, len(m.chatStore.chatRuns)+len(m.chatStore.historyActiveRuns)) + for _, run := range m.chatStore.chatRuns { + if run == nil || run.done { + continue + } + conversationID := strings.TrimSpace(run.conversationID) + if conversationID == "" { + continue + } + summary := ActiveChatRunSummary{ + ConversationID: conversationID, + Workdir: strings.TrimSpace(run.workdir), + UpdatedAt: run.updatedAt.UnixMilli(), + } + if index, ok := seen[conversationID]; ok { + if summaries[index].Workdir == "" { + summaries[index].Workdir = summary.Workdir + } + if summary.UpdatedAt > summaries[index].UpdatedAt { + summaries[index].UpdatedAt = summary.UpdatedAt + } + continue + } + seen[conversationID] = len(summaries) + summaries = append(summaries, summary) + } + + for conversationID, run := range m.chatStore.historyActiveRuns { + conversationID = strings.TrimSpace(conversationID) + if conversationID == "" { + continue + } + if now.Sub(run.updatedAt) > chatRunStaleRetention { + delete(m.chatStore.historyActiveRuns, conversationID) + continue + } + workdir := strings.TrimSpace(run.workdir) + updatedAt := run.updatedAt.UnixMilli() + if index, ok := seen[conversationID]; ok { + if summaries[index].Workdir == "" { + summaries[index].Workdir = workdir + } + if updatedAt > summaries[index].UpdatedAt { + summaries[index].UpdatedAt = updatedAt + } + continue + } + seen[conversationID] = len(summaries) + summaries = append(summaries, ActiveChatRunSummary{ + ConversationID: conversationID, + Workdir: workdir, + UpdatedAt: updatedAt, + }) + } + + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].ConversationID < summaries[j].ConversationID + }) + return summaries +} + +func (m *Manager) ActiveChatRunConversationIDs() []string { + summaries := m.ActiveChatRunSummaries() + ids := make([]string, 0, len(summaries)) + for _, summary := range summaries { + if conversationID := strings.TrimSpace(summary.ConversationID); conversationID != "" { + ids = append(ids, conversationID) + } + } + return ids +} + +func (m *Manager) failOpenChatRunsForSessionEpoch(sessionEpoch uint64, message string) { + message = strings.TrimSpace(message) + if message == "" { + message = agentDisconnectedChatRunMessage + } + + data, err := json.Marshal(map[string]string{"message": message}) + if err != nil { + data = []byte(`{"message":"Desktop agent disconnected. Please retry."}`) + } + now := time.Now() + + type broadcastTarget struct { + event *ChatBroadcastEvent + subscribers []*chatRunSubscriber + } + targets := make([]broadcastTarget, 0) + globalSubscribers := make([]chan *ChatBroadcastEvent, 0) + + m.chatStore.chatMu.Lock() + m.pruneExpiredChatRunsLocked(now) + for requestID, run := range m.chatStore.chatRuns { + if run == nil || run.done || run.sessionEpoch != sessionEpoch { + continue + } + + run.nextSeq += 1 + run.updatedAt = now + run.done = true + run.expiresAt = now.Add(chatRunDoneRetention) + + chatEvent := &gatewayv1.ChatEvent{ + Type: gatewayv1.ChatEvent_ERROR, + ConversationId: strings.TrimSpace(run.conversationID), + Data: string(data), + } + broadcast := &ChatBroadcastEvent{ + RequestID: requestID, + Event: chatEvent, + Seq: run.nextSeq, + Workdir: strings.TrimSpace(run.workdir), + } + run.events = append(run.events, cloneChatBroadcastEvent(broadcast)) + if len(run.events) > maxBufferedChatRunEvents { + copy(run.events, run.events[len(run.events)-maxBufferedChatRunEvents:]) + run.events = run.events[:maxBufferedChatRunEvents] + } + + subscribers := make([]*chatRunSubscriber, 0, len(run.subscribers)) + for _, subscriber := range run.subscribers { + subscribers = append(subscribers, subscriber) + } + targets = append(targets, broadcastTarget{ + event: broadcast, + subscribers: subscribers, + }) + } + for _, ch := range m.chatStore.chatSubscribers { + globalSubscribers = append(globalSubscribers, ch) + } + m.chatStore.chatMu.Unlock() + + for _, target := range targets { + for _, subscriber := range target.subscribers { + select { + case <-subscriber.done: + case subscriber.ch <- cloneChatBroadcastEvent(target.event): + } + } + for _, ch := range globalSubscribers { + select { + case ch <- cloneChatBroadcastEvent(target.event): + default: + } + } + } +} + +func (m *Manager) SubscribeChatRun( + requestID string, + conversationID string, + afterSeq int64, +) (<-chan *ChatBroadcastEvent, <-chan struct{}, func(), ChatRunSnapshot, error) { + requestID = strings.TrimSpace(requestID) + conversationID = strings.TrimSpace(conversationID) + if afterSeq < 0 { + afterSeq = 0 + } + + m.chatStore.chatMu.Lock() + defer m.chatStore.chatMu.Unlock() + m.pruneExpiredChatRunsLocked(time.Now()) + + if requestID == "" && conversationID != "" { + requestID = m.chatStore.chatRunByConversation[conversationID] + } + run := m.chatStore.chatRuns[requestID] + if run == nil { + done := make(chan struct{}) + close(done) + return nil, done, func() {}, ChatRunSnapshot{}, ErrChatRunNotFound + } + + replay := make([]*ChatBroadcastEvent, 0) + for _, event := range run.events { + if event.Seq > afterSeq { + replay = append(replay, cloneChatBroadcastEvent(event)) + } + } + + bufferSize := len(replay) + 128 + if bufferSize < 128 { + bufferSize = 128 + } + ch := make(chan *ChatBroadcastEvent, bufferSize) + done := make(chan struct{}) + for _, event := range replay { + ch <- event + } + + subID := -1 + var subscriber *chatRunSubscriber + if !run.done { + subID = m.chatStore.nextChatRunSubID + m.chatStore.nextChatRunSubID += 1 + subscriber = &chatRunSubscriber{ + ch: ch, + done: done, + } + run.subscribers[subID] = subscriber + } + + var cleanupOnce sync.Once + cleanup := func() { + cleanupOnce.Do(func() { + m.chatStore.chatMu.Lock() + if subID >= 0 { + if current := m.chatStore.chatRuns[requestID]; current != nil { + delete(current.subscribers, subID) + } + } + m.chatStore.chatMu.Unlock() + if subscriber != nil { + subscriber.close() + } else { + close(done) + } + }) + } + + return ch, done, cleanup, run.snapshot(), nil +} + +func (m *Manager) broadcastChatEvent(requestID string, event *gatewayv1.ChatEvent) { + if event == nil { + return + } + + requestID = strings.TrimSpace(requestID) + conversationID := strings.TrimSpace(event.GetConversationId()) + now := time.Now() + sessionEpoch := m.currentSessionEpoch() + + m.chatStore.chatMu.Lock() + m.pruneExpiredChatRunsLocked(now) + broadcast := &ChatBroadcastEvent{ + RequestID: requestID, + Event: event, + } + var runSubscribers []*chatRunSubscriber + run := m.chatStore.chatRuns[requestID] + if run == nil && requestID != "" { + run = &chatRun{ + requestID: requestID, + conversationID: conversationID, + sessionEpoch: sessionEpoch, + updatedAt: now, + subscribers: make(map[int]*chatRunSubscriber), + } + m.chatStore.chatRuns[requestID] = run + if conversationID != "" { + m.chatStore.chatRunByConversation[conversationID] = requestID + } + } + if run != nil { + run.nextSeq += 1 + run.updatedAt = now + if conversationID != "" { + if run.conversationID != "" && run.conversationID != conversationID { + if m.chatStore.chatRunByConversation[run.conversationID] == requestID { + delete(m.chatStore.chatRunByConversation, run.conversationID) + } + } + run.conversationID = conversationID + m.chatStore.chatRunByConversation[conversationID] = requestID + if run.workdir == "" { + if activeRun, ok := m.chatStore.historyActiveRuns[conversationID]; ok { + run.workdir = strings.TrimSpace(activeRun.workdir) + } + } + } + broadcast.Seq = run.nextSeq + broadcast.Workdir = strings.TrimSpace(run.workdir) + run.events = append(run.events, cloneChatBroadcastEvent(broadcast)) + if len(run.events) > maxBufferedChatRunEvents { + copy(run.events, run.events[len(run.events)-maxBufferedChatRunEvents:]) + run.events = run.events[:maxBufferedChatRunEvents] + } + if isTerminalChatEvent(event) { + run.done = true + run.expiresAt = now.Add(chatRunDoneRetention) + } + runSubscribers = make([]*chatRunSubscriber, 0, len(run.subscribers)) + for _, subscriber := range run.subscribers { + runSubscribers = append(runSubscribers, subscriber) + } + } + subscribers := make([]chan *ChatBroadcastEvent, 0, len(m.chatStore.chatSubscribers)) + for _, ch := range m.chatStore.chatSubscribers { + subscribers = append(subscribers, ch) + } + m.chatStore.chatMu.Unlock() + + for _, subscriber := range runSubscribers { + select { + case <-subscriber.done: + case subscriber.ch <- cloneChatBroadcastEvent(broadcast): + } + } + for _, ch := range subscribers { + select { + case ch <- broadcast: + default: + } + } +} + +func (m *Manager) DispatchFromAgent(env *gatewayv1.AgentEnvelope) { + m.dispatchFromAgent(nil, env) +} + +func (m *Manager) DispatchFromAgentForSession(session *AgentSession, env *gatewayv1.AgentEnvelope) { + m.dispatchFromAgent(session, env) +} + +func (m *Manager) dispatchFromAgent(expected *AgentSession, env *gatewayv1.AgentEnvelope) { + m.registry.mu.RLock() + session := m.registry.session + m.registry.mu.RUnlock() + if session == nil || (expected != nil && session != expected) { + return + } + + if chatEvent := env.GetChatEvent(); chatEvent != nil { + m.broadcastChatEvent(env.GetRequestId(), chatEvent) + } + + if historySync := env.GetHistorySync(); historySync != nil { + m.broadcastHistorySync(historySync) + return + } + + if settingsSync := env.GetSettingsSync(); settingsSync != nil { + m.broadcastSettingsSync(settingsSync) + return + } + + if terminalEvent := env.GetTerminalEvent(); terminalEvent != nil { + m.broadcastTerminalEvent(terminalEvent) + return + } + + session.dispatch(env) +} + +func (r *chatRun) snapshot() ChatRunSnapshot { + firstSeq := int64(0) + if len(r.events) > 0 { + firstSeq = r.events[0].Seq + } + return ChatRunSnapshot{ + RequestID: r.requestID, + ConversationID: r.conversationID, + ClientRequestID: r.clientRequestID, + Workdir: r.workdir, + FirstSeq: firstSeq, + LatestSeq: r.nextSeq, + Done: r.done, + } +} + +func (s *chatRunSubscriber) close() { + s.closeOnce.Do(func() { + close(s.done) + }) +} + +func (m *Manager) pruneExpiredChatRunsLocked(now time.Time) { + for requestID, run := range m.chatStore.chatRuns { + if run == nil { + delete(m.chatStore.chatRuns, requestID) + continue + } + if run.done { + if !run.expiresAt.IsZero() && now.After(run.expiresAt) { + m.removeChatRunLocked(requestID, run) + } + continue + } + if !run.updatedAt.IsZero() && now.Sub(run.updatedAt) > chatRunStaleRetention { + m.removeChatRunLocked(requestID, run) + } + } +} + +func (m *Manager) removeChatRunLocked(requestID string, run *chatRun) { + if run.conversationID != "" && m.chatStore.chatRunByConversation[run.conversationID] == requestID { + delete(m.chatStore.chatRunByConversation, run.conversationID) + } + if run.clientRequestID != "" && m.chatStore.chatRunByClientRequest[run.clientRequestID] == requestID { + delete(m.chatStore.chatRunByClientRequest, run.clientRequestID) + } + delete(m.chatStore.chatRuns, requestID) + for _, subscriber := range run.subscribers { + subscriber.close() + } +} + +func (m *Manager) releaseCompletedChatRunLocked(requestID string, run *chatRun) { + if run.conversationID != "" && m.chatStore.chatRunByConversation[run.conversationID] == requestID { + delete(m.chatStore.chatRunByConversation, run.conversationID) + } + if run.clientRequestID != "" && m.chatStore.chatRunByClientRequest[run.clientRequestID] == requestID { + delete(m.chatStore.chatRunByClientRequest, run.clientRequestID) + } + delete(m.chatStore.chatRuns, requestID) +} + +func cloneChatBroadcastEvent(event *ChatBroadcastEvent) *ChatBroadcastEvent { + if event == nil { + return nil + } + return &ChatBroadcastEvent{ + RequestID: event.RequestID, + Event: event.Event, + Seq: event.Seq, + Workdir: event.Workdir, + } +} + +func isTerminalChatEvent(event *gatewayv1.ChatEvent) bool { + if event == nil { + return false + } + return event.GetType() == gatewayv1.ChatEvent_DONE || event.GetType() == gatewayv1.ChatEvent_ERROR +} diff --git a/crates/agent-gateway/internal/session/manager_history_sync.go b/crates/agent-gateway/internal/session/manager_history_sync.go new file mode 100644 index 000000000..37ab4fb6f --- /dev/null +++ b/crates/agent-gateway/internal/session/manager_history_sync.go @@ -0,0 +1,137 @@ +package session + +import ( + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (m *Manager) SubscribeHistorySync() (<-chan *gatewayv1.HistorySyncEvent, func()) { + ch := make(chan *gatewayv1.HistorySyncEvent, 32) + + m.syncHub.historyMu.Lock() + subID := m.syncHub.nextHistorySubID + m.syncHub.nextHistorySubID += 1 + m.syncHub.historySubscribers[subID] = ch + m.syncHub.historyMu.Unlock() + + cleanup := func() { + m.syncHub.historyMu.Lock() + existing, ok := m.syncHub.historySubscribers[subID] + if ok { + delete(m.syncHub.historySubscribers, subID) + close(existing) + } + m.syncHub.historyMu.Unlock() + } + + return ch, cleanup +} + +func (m *Manager) broadcastHistorySync(event *gatewayv1.HistorySyncEvent) { + if event == nil { + return + } + + m.updateActiveHistoryRun(event) + m.releaseCompletedChatRunAfterHistoryUpsert(event) + + m.syncHub.historyMu.Lock() + subscribers := make([]chan *gatewayv1.HistorySyncEvent, 0, len(m.syncHub.historySubscribers)) + for _, ch := range m.syncHub.historySubscribers { + subscribers = append(subscribers, ch) + } + m.syncHub.historyMu.Unlock() + + for _, ch := range subscribers { + select { + case ch <- event: + default: + } + } +} + +func historySyncConversationID(event *gatewayv1.HistorySyncEvent) string { + conversationID := strings.TrimSpace(event.GetConversationId()) + if conversationID == "" && event.GetConversation() != nil { + conversationID = strings.TrimSpace(event.GetConversation().GetId()) + } + return conversationID +} + +func historySyncWorkdir(event *gatewayv1.HistorySyncEvent) string { + if event == nil || event.GetConversation() == nil { + return "" + } + return strings.TrimSpace(event.GetConversation().GetCwd()) +} + +func (m *Manager) updateActiveHistoryRun(event *gatewayv1.HistorySyncEvent) { + kind := strings.TrimSpace(event.GetKind()) + conversationID := historySyncConversationID(event) + if conversationID == "" { + return + } + + workdir := historySyncWorkdir(event) + now := time.Now() + + m.chatStore.chatMu.Lock() + defer m.chatStore.chatMu.Unlock() + m.pruneExpiredChatRunsLocked(now) + + switch kind { + case "running": + existing := m.chatStore.historyActiveRuns[conversationID] + if workdir == "" { + workdir = existing.workdir + } + m.chatStore.historyActiveRuns[conversationID] = activeHistoryRun{ + conversationID: conversationID, + workdir: workdir, + updatedAt: now, + } + if requestID := m.chatStore.chatRunByConversation[conversationID]; requestID != "" { + if run := m.chatStore.chatRuns[requestID]; run != nil && workdir != "" { + run.workdir = workdir + } + } + case "idle", "delete": + delete(m.chatStore.historyActiveRuns, conversationID) + case "upsert": + if workdir == "" { + return + } + if existing, ok := m.chatStore.historyActiveRuns[conversationID]; ok { + existing.workdir = workdir + existing.updatedAt = now + m.chatStore.historyActiveRuns[conversationID] = existing + } + if requestID := m.chatStore.chatRunByConversation[conversationID]; requestID != "" { + if run := m.chatStore.chatRuns[requestID]; run != nil { + run.workdir = workdir + } + } + } +} + +func (m *Manager) releaseCompletedChatRunAfterHistoryUpsert(event *gatewayv1.HistorySyncEvent) { + if strings.TrimSpace(event.GetKind()) != "upsert" { + return + } + + conversationID := historySyncConversationID(event) + if conversationID == "" { + return + } + + m.chatStore.chatMu.Lock() + defer m.chatStore.chatMu.Unlock() + requestID := m.chatStore.chatRunByConversation[conversationID] + run := m.chatStore.chatRuns[requestID] + if run == nil || !run.done { + return + } + m.releaseCompletedChatRunLocked(requestID, run) +} diff --git a/crates/agent-gateway/internal/session/manager_registry.go b/crates/agent-gateway/internal/session/manager_registry.go new file mode 100644 index 000000000..09609a19c --- /dev/null +++ b/crates/agent-gateway/internal/session/manager_registry.go @@ -0,0 +1,141 @@ +package session + +import ( + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (m *Manager) RecordAuthentication(agentID, agentVersion, sessionID string) { + m.registry.mu.Lock() + defer m.registry.mu.Unlock() + m.registry.lastAuth = AuthSnapshot{ + AgentID: agentID, + AgentVersion: agentVersion, + SessionID: sessionID, + } + m.registry.authValid = true +} + +func (m *Manager) LatestAuthSnapshot() AuthSnapshot { + m.registry.mu.RLock() + defer m.registry.mu.RUnlock() + return m.registry.lastAuth +} + +func (m *Manager) IsOnline() bool { + m.registry.mu.RLock() + defer m.registry.mu.RUnlock() + return m.registry.session != nil +} + +func (m *Manager) SetSession(s *AgentSession) { + m.registry.mu.Lock() + previous := m.registry.session + previousEpoch := m.registry.sessionEpoch + if m.registry.authValid { + s.AgentID = m.registry.lastAuth.AgentID + s.AgentVersion = m.registry.lastAuth.AgentVersion + s.SessionID = m.registry.lastAuth.SessionID + } + if previous != s { + m.registry.sessionEpoch += 1 + } + sessionChanged := previous != s + m.registry.session = s + m.registry.mu.Unlock() + + if sessionChanged { + m.clearTerminalSessionSnapshot() + } + if previous != nil && previous != s { + previous.Close() + m.failOpenChatRunsForSessionEpoch(previousEpoch, agentDisconnectedChatRunMessage) + } +} + +func (m *Manager) ClearSession(session *AgentSession) { + m.registry.mu.Lock() + if m.registry.session != session { + m.registry.mu.Unlock() + return + } + clearedEpoch := m.registry.sessionEpoch + m.registry.session = nil + m.registry.mu.Unlock() + + if session == nil { + return + } + + session.Close() + m.clearTerminalSessionSnapshot() + m.failOpenChatRunsForSessionEpoch(clearedEpoch, agentDisconnectedChatRunMessage) +} + +func (m *Manager) Status() Status { + m.registry.mu.RLock() + defer m.registry.mu.RUnlock() + + status := Status{} + if m.registry.authValid { + status.AgentID = m.registry.lastAuth.AgentID + status.AgentVersion = m.registry.lastAuth.AgentVersion + status.SessionID = m.registry.lastAuth.SessionID + } + if m.registry.session == nil { + return status + } + status.Online = true + status.AgentID = m.registry.session.AgentID + status.AgentVersion = m.registry.session.AgentVersion + status.SessionID = m.registry.session.SessionID + status.ConnectedSince = m.registry.session.ConnectedAt.Unix() + status.LastHeartbeat = m.registry.session.LastPing.Unix() + return status +} + +func (m *Manager) TouchHeartbeat(session *AgentSession) { + m.registry.mu.Lock() + defer m.registry.mu.Unlock() + if m.registry.session == session { + m.registry.session.LastPing = time.Now() + } +} + +func (m *Manager) SendToAgent(env *gatewayv1.GatewayEnvelope) error { + m.registry.mu.RLock() + session := m.registry.session + m.registry.mu.RUnlock() + if session == nil { + return ErrAgentOffline + } + + return session.SendToAgent(env) +} + +func (m *Manager) currentSessionEpoch() uint64 { + m.registry.mu.RLock() + defer m.registry.mu.RUnlock() + return m.registry.sessionEpoch +} + +func (m *Manager) RegisterStream(requestID string) (<-chan *gatewayv1.AgentEnvelope, <-chan struct{}, func(), error) { + m.registry.mu.RLock() + session := m.registry.session + m.registry.mu.RUnlock() + if session == nil { + return nil, nil, nil, ErrAgentOffline + } + + stream, err := session.registerStream(requestID) + if err != nil { + return nil, nil, nil, err + } + + cleanup := func() { + session.unregisterStream(requestID, stream) + } + + return stream.ch, stream.done, cleanup, nil +} diff --git a/crates/agent-gateway/internal/session/manager_settings_sync.go b/crates/agent-gateway/internal/session/manager_settings_sync.go new file mode 100644 index 000000000..1651d55d5 --- /dev/null +++ b/crates/agent-gateway/internal/session/manager_settings_sync.go @@ -0,0 +1,124 @@ +package session + +import ( + "encoding/json" + "strings" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (m *Manager) SubscribeSettingsSync() (<-chan *gatewayv1.SettingsSyncEvent, func()) { + ch := make(chan *gatewayv1.SettingsSyncEvent, 32) + + m.syncHub.settingsMu.Lock() + subID := m.syncHub.nextSettingsSubID + m.syncHub.nextSettingsSubID += 1 + m.syncHub.settingsSubscribers[subID] = ch + m.syncHub.settingsMu.Unlock() + + cleanup := func() { + m.syncHub.settingsMu.Lock() + existing, ok := m.syncHub.settingsSubscribers[subID] + if ok { + delete(m.syncHub.settingsSubscribers, subID) + close(existing) + } + m.syncHub.settingsMu.Unlock() + } + + return ch, cleanup +} + +func (m *Manager) WebTerminalEnabled() bool { + m.syncHub.settingsSnapshotMu.RLock() + defer m.syncHub.settingsSnapshotMu.RUnlock() + + remote, ok := m.syncHub.settingsSnapshot["remote"].(map[string]any) + if !ok { + return false + } + enabled, ok := remote["enableWebTerminal"].(bool) + return ok && enabled +} + +func (m *Manager) WebGitEnabled() bool { + m.syncHub.settingsSnapshotMu.RLock() + defer m.syncHub.settingsSnapshotMu.RUnlock() + + remote, ok := m.syncHub.settingsSnapshot["remote"].(map[string]any) + if !ok { + return false + } + enabled, ok := remote["enableWebGit"].(bool) + return ok && enabled +} + +func (m *Manager) updateSettingsSnapshot(event *gatewayv1.SettingsSyncEvent) { + if event == nil { + return + } + m.ApplySettingsJSON(event.GetSettingsJson()) +} + +func parseSettingsJSON(settingsJSON string) (map[string]any, bool) { + raw := strings.TrimSpace(settingsJSON) + if raw == "" { + return nil, false + } + var payload map[string]any + if err := json.Unmarshal([]byte(raw), &payload); err != nil || payload == nil { + return nil, false + } + return payload, true +} + +func (m *Manager) ApplySettingsJSON(settingsJSON string) { + payload, ok := parseSettingsJSON(settingsJSON) + if !ok { + return + } + m.syncHub.settingsSnapshotMu.Lock() + if _, hasIncomingRemote := payload["remote"]; !hasIncomingRemote { + if existingRemote, hasExistingRemote := m.syncHub.settingsSnapshot["remote"]; hasExistingRemote { + payload["remote"] = existingRemote + } + } + m.syncHub.settingsSnapshot = payload + m.syncHub.settingsSnapshotMu.Unlock() +} + +func (m *Manager) ApplySettingsJSONPreservingRemote(settingsJSON string) { + payload, ok := parseSettingsJSON(settingsJSON) + if !ok { + return + } + m.syncHub.settingsSnapshotMu.Lock() + if existingRemote, ok := m.syncHub.settingsSnapshot["remote"]; ok { + payload["remote"] = existingRemote + } else { + delete(payload, "remote") + } + m.syncHub.settingsSnapshot = payload + m.syncHub.settingsSnapshotMu.Unlock() +} + +func (m *Manager) broadcastSettingsSync(event *gatewayv1.SettingsSyncEvent) { + if event == nil { + return + } + m.updateSettingsSnapshot(event) + + m.syncHub.settingsMu.Lock() + subscribers := make([]chan *gatewayv1.SettingsSyncEvent, 0, len(m.syncHub.settingsSubscribers)) + for _, ch := range m.syncHub.settingsSubscribers { + subscribers = append(subscribers, ch) + } + m.syncHub.settingsMu.Unlock() + + for _, ch := range subscribers { + select { + case ch <- event: + default: + } + } +} diff --git a/crates/agent-gateway/internal/session/manager_state.go b/crates/agent-gateway/internal/session/manager_state.go new file mode 100644 index 000000000..a32a4528b --- /dev/null +++ b/crates/agent-gateway/internal/session/manager_state.go @@ -0,0 +1,66 @@ +package session + +import ( + "sync" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +type sessionRegistry struct { + mu sync.RWMutex + session *AgentSession + sessionEpoch uint64 + lastAuth AuthSnapshot + authValid bool +} + +func newSessionRegistry() *sessionRegistry { + return &sessionRegistry{} +} + +type syncHub struct { + historyMu sync.Mutex + nextHistorySubID int + historySubscribers map[int]chan *gatewayv1.HistorySyncEvent + + settingsMu sync.Mutex + nextSettingsSubID int + settingsSubscribers map[int]chan *gatewayv1.SettingsSyncEvent + settingsSnapshotMu sync.RWMutex + settingsSnapshot map[string]any + + terminalMu sync.Mutex + nextTerminalSubID int + terminalSubscribers map[int]chan *gatewayv1.TerminalEvent + terminalSessions map[string]*gatewayv1.TerminalSession +} + +func newSyncHub() *syncHub { + return &syncHub{ + historySubscribers: make(map[int]chan *gatewayv1.HistorySyncEvent), + settingsSubscribers: make(map[int]chan *gatewayv1.SettingsSyncEvent), + terminalSubscribers: make(map[int]chan *gatewayv1.TerminalEvent), + terminalSessions: make(map[string]*gatewayv1.TerminalSession), + } +} + +type chatRunStore struct { + chatMu sync.Mutex + nextChatSubID int + chatSubscribers map[int]chan *ChatBroadcastEvent + nextChatRunSubID int + chatRuns map[string]*chatRun + chatRunByConversation map[string]string + chatRunByClientRequest map[string]string + historyActiveRuns map[string]activeHistoryRun +} + +func newChatRunStore() *chatRunStore { + return &chatRunStore{ + chatSubscribers: make(map[int]chan *ChatBroadcastEvent), + chatRuns: make(map[string]*chatRun), + chatRunByConversation: make(map[string]string), + chatRunByClientRequest: make(map[string]string), + historyActiveRuns: make(map[string]activeHistoryRun), + } +} diff --git a/crates/agent-gateway/internal/session/manager_terminal.go b/crates/agent-gateway/internal/session/manager_terminal.go new file mode 100644 index 000000000..54416c741 --- /dev/null +++ b/crates/agent-gateway/internal/session/manager_terminal.go @@ -0,0 +1,210 @@ +package session + +import ( + "sort" + "strings" + "time" + + gatewayv1 "github.com/liveagent/agent-gateway/internal/proto/v1" +) + +func (m *Manager) SubscribeTerminalEvents() (<-chan *gatewayv1.TerminalEvent, func()) { + ch := make(chan *gatewayv1.TerminalEvent, 4096) + + m.syncHub.terminalMu.Lock() + subID := m.syncHub.nextTerminalSubID + m.syncHub.nextTerminalSubID += 1 + m.syncHub.terminalSubscribers[subID] = ch + m.syncHub.terminalMu.Unlock() + + cleanup := func() { + m.syncHub.terminalMu.Lock() + existing, ok := m.syncHub.terminalSubscribers[subID] + if ok { + delete(m.syncHub.terminalSubscribers, subID) + close(existing) + } + m.syncHub.terminalMu.Unlock() + } + + return ch, cleanup +} + +func cloneTerminalSession(session *gatewayv1.TerminalSession) *gatewayv1.TerminalSession { + if session == nil { + return nil + } + return &gatewayv1.TerminalSession{ + Id: session.GetId(), + ProjectPathKey: session.GetProjectPathKey(), + Cwd: session.GetCwd(), + Shell: session.GetShell(), + Title: session.GetTitle(), + Pid: session.GetPid(), + Cols: session.GetCols(), + Rows: session.GetRows(), + CreatedAt: session.GetCreatedAt(), + UpdatedAt: session.GetUpdatedAt(), + FinishedAt: session.GetFinishedAt(), + ExitCode: session.GetExitCode(), + Running: session.GetRunning(), + } +} + +func terminalSessionSortKey(session *gatewayv1.TerminalSession) (string, uint64, string) { + if session == nil { + return "", 0, "" + } + return strings.TrimSpace(session.GetProjectPathKey()), session.GetCreatedAt(), strings.TrimSpace(session.GetId()) +} + +func sortTerminalSessions(sessions []*gatewayv1.TerminalSession) { + sort.Slice(sessions, func(i, j int) bool { + leftProject, leftCreatedAt, leftID := terminalSessionSortKey(sessions[i]) + rightProject, rightCreatedAt, rightID := terminalSessionSortKey(sessions[j]) + if leftProject != rightProject { + return leftProject < rightProject + } + if leftCreatedAt != rightCreatedAt { + return leftCreatedAt < rightCreatedAt + } + return leftID < rightID + }) +} + +func terminalSessionMatchesProject(session *gatewayv1.TerminalSession, projectPathKey string) bool { + projectPathKey = strings.TrimSpace(projectPathKey) + if projectPathKey == "" { + return true + } + if session == nil { + return false + } + return strings.TrimSpace(session.GetProjectPathKey()) == projectPathKey +} + +func (m *Manager) clearTerminalSessionSnapshot() { + m.syncHub.terminalMu.Lock() + m.syncHub.terminalSessions = make(map[string]*gatewayv1.TerminalSession) + m.syncHub.terminalMu.Unlock() +} + +func (m *Manager) TerminalSessionSnapshot(projectPathKey string) []*gatewayv1.TerminalSession { + projectPathKey = strings.TrimSpace(projectPathKey) + m.syncHub.terminalMu.Lock() + sessions := make([]*gatewayv1.TerminalSession, 0, len(m.syncHub.terminalSessions)) + for _, session := range m.syncHub.terminalSessions { + if !terminalSessionMatchesProject(session, projectPathKey) { + continue + } + if cloned := cloneTerminalSession(session); cloned != nil { + sessions = append(sessions, cloned) + } + } + m.syncHub.terminalMu.Unlock() + sortTerminalSessions(sessions) + return sessions +} + +func (m *Manager) ReplaceTerminalSessionSnapshot( + projectPathKey string, + sessions []*gatewayv1.TerminalSession, +) { + projectPathKey = strings.TrimSpace(projectPathKey) + m.syncHub.terminalMu.Lock() + if projectPathKey == "" { + m.syncHub.terminalSessions = make(map[string]*gatewayv1.TerminalSession) + } else { + for id, session := range m.syncHub.terminalSessions { + if terminalSessionMatchesProject(session, projectPathKey) { + delete(m.syncHub.terminalSessions, id) + } + } + } + for _, session := range sessions { + id := strings.TrimSpace(session.GetId()) + if id == "" { + continue + } + m.syncHub.terminalSessions[id] = cloneTerminalSession(session) + } + m.syncHub.terminalMu.Unlock() +} + +func (m *Manager) ApplyTerminalResponseSnapshot( + action string, + projectPathKey string, + resp *gatewayv1.TerminalResponse, +) { + if resp == nil { + return + } + action = strings.TrimSpace(action) + projectPathKey = strings.TrimSpace(projectPathKey) + + switch action { + case "list": + m.ReplaceTerminalSessionSnapshot(projectPathKey, resp.GetSessions()) + case "close_project": + m.ReplaceTerminalSessionSnapshot(projectPathKey, nil) + case "close": + if sessionID := strings.TrimSpace(resp.GetSession().GetId()); sessionID != "" { + m.syncHub.terminalMu.Lock() + delete(m.syncHub.terminalSessions, sessionID) + m.syncHub.terminalMu.Unlock() + } + case "create", "attach", "snapshot", "input", "resize", "rename": + session := resp.GetSession() + sessionID := strings.TrimSpace(session.GetId()) + if sessionID == "" { + return + } + m.syncHub.terminalMu.Lock() + m.syncHub.terminalSessions[sessionID] = cloneTerminalSession(session) + m.syncHub.terminalMu.Unlock() + } +} + +func (m *Manager) applyTerminalEventSnapshot(event *gatewayv1.TerminalEvent) { + if event == nil { + return + } + kind := strings.TrimSpace(event.GetKind()) + sessionID := strings.TrimSpace(event.GetSessionId()) + if sessionID == "" && event.GetSession() != nil { + sessionID = strings.TrimSpace(event.GetSession().GetId()) + } + if sessionID == "" { + return + } + + m.syncHub.terminalMu.Lock() + if kind == "closed" { + delete(m.syncHub.terminalSessions, sessionID) + } else if session := cloneTerminalSession(event.GetSession()); session != nil { + m.syncHub.terminalSessions[sessionID] = session + } + m.syncHub.terminalMu.Unlock() +} + +func (m *Manager) broadcastTerminalEvent(event *gatewayv1.TerminalEvent) { + if event == nil { + return + } + + m.applyTerminalEventSnapshot(event) + + m.syncHub.terminalMu.Lock() + subscribers := make([]chan *gatewayv1.TerminalEvent, 0, len(m.syncHub.terminalSubscribers)) + for _, ch := range m.syncHub.terminalSubscribers { + subscribers = append(subscribers, ch) + } + m.syncHub.terminalMu.Unlock() + + for _, ch := range subscribers { + select { + case ch <- event: + case <-time.After(50 * time.Millisecond): + } + } +} diff --git a/crates/agent-gateway/web/src/pages/LoginPage.tsx b/crates/agent-gateway/web/src/pages/LoginPage.tsx index 178ce7dd3..a43a8570f 100644 --- a/crates/agent-gateway/web/src/pages/LoginPage.tsx +++ b/crates/agent-gateway/web/src/pages/LoginPage.tsx @@ -1,5 +1,13 @@ import { useState } from "react"; -import { MessageSquareText, History, Timer, ArrowRight, Shield } from "../components/icons"; +import { + MessageSquareText, + History, + Timer, + ArrowRight, + Shield, + Key, + Lock, +} from "../components/icons"; import { Button } from "../components/ui/button"; import { Textarea } from "../components/ui/textarea"; import { cn } from "../lib/shared/utils"; @@ -17,19 +25,19 @@ const features = [ icon: MessageSquareText, title: "Remote Chat", desc: "按桌面端式样查看 token、thinking、tool_call 与 tool_result。", - accent: "login-feature-accent-blue", + accent: "login-feat--blue", }, { icon: History, title: "History Resume", desc: "从远程历史回填会话并继续对话,而不是只看原始 JSON。", - accent: "login-feature-accent-violet", + accent: "login-feat--violet", }, { icon: Timer, title: "Cron Control", desc: "在浏览器里完成任务查看、创建、更新与删除的转发调试。", - accent: "login-feature-accent-amber", + accent: "login-feat--amber", }, ]; @@ -44,71 +52,99 @@ export function LoginPage({ return (
- {/* Ambient background glow */} -