Skip to content

Commit 2959dc9

Browse files
authored
Merge pull request #830 from rumpl/remote-oauth
Add unmanaged mode for oauth flow
2 parents 914eee1 + be219c2 commit 2959dc9

18 files changed

Lines changed: 543 additions & 262 deletions

File tree

pkg/agent/agent_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ func (s *stubToolSet) Tools(context.Context) ([]tools.Tool, error) {
3737
func (s *stubToolSet) Instructions() string { return s.instructions }
3838
func (s *stubToolSet) SetElicitationHandler(tools.ElicitationHandler) {}
3939
func (s *stubToolSet) SetOAuthSuccessHandler(func()) {}
40+
func (s *stubToolSet) SetManagedOAuth(bool) {}
4041

4142
func TestAgentTools(t *testing.T) {
4243
tests := []struct {

pkg/api/types.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,6 @@ type DesktopTokenResponse struct {
155155
Token string `json:"token"`
156156
}
157157

158-
// ResumeStartOauthRequest represents the user approval to start the OAuth flow
159-
type ResumeStartOauthRequest struct {
160-
Confirmation bool `json:"confirmation"`
161-
}
162-
163-
// ResumeCodeReceivedOauthRequest represents the response from getting the OAuth URL with code and state
164-
type ResumeCodeReceivedOauthRequest struct {
165-
Code string `json:"code"`
166-
State string `json:"state"`
167-
}
168-
169158
// ResumeElicitationRequest represents a request to resume with an elicitation response
170159
type ResumeElicitationRequest struct {
171160
Action string `json:"action"` // "accept", "decline", or "cancel"

pkg/app/app.go

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/docker/cagent/pkg/runtime"
1111
"github.com/docker/cagent/pkg/session"
12+
"github.com/docker/cagent/pkg/tools"
1213
)
1314

1415
type App struct {
@@ -87,9 +88,12 @@ func (a *App) Subscribe(ctx context.Context, program *tea.Program) {
8788

8889
// Resume resumes the runtime with the given confirmation type
8990
func (a *App) Resume(resumeType runtime.ResumeType) {
90-
if a.runtime != nil {
91-
a.runtime.Resume(context.Background(), resumeType)
92-
}
91+
a.runtime.Resume(context.Background(), resumeType)
92+
}
93+
94+
// ResumeElicitation resumes an elicitation request with the given action and content
95+
func (a *App) ResumeElicitation(ctx context.Context, action tools.ElicitationAction, content map[string]any) error {
96+
return a.runtime.ResumeElicitation(ctx, action, content)
9397
}
9498

9599
func (a *App) NewSession() {
@@ -105,7 +109,7 @@ func (a *App) Session() *session.Session {
105109
}
106110

107111
func (a *App) CompactSession() {
108-
if a.runtime != nil && a.session != nil {
112+
if a.session != nil {
109113
events := make(chan runtime.Event, 100)
110114
a.runtime.Summarize(context.Background(), a.session, events)
111115
close(events)
@@ -115,14 +119,6 @@ func (a *App) CompactSession() {
115119
}
116120
}
117121

118-
// ResumeStartOAuth resumes the runtime with OAuth authorization confirmation
119-
func (a *App) ResumeStartOAuth(bool) {
120-
if a.runtime != nil {
121-
// TODO(rumpl): handle the error
122-
_ = a.runtime.ResumeElicitation(context.Background(), "accept", nil)
123-
}
124-
}
125-
126122
func (a *App) PlainTextTranscript() string {
127123
return transcript(a.session)
128124
}

pkg/runtime/client.go

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/docker/cagent/pkg/api"
1717
v2 "github.com/docker/cagent/pkg/config/v2"
1818
"github.com/docker/cagent/pkg/session"
19+
"github.com/docker/cagent/pkg/tools"
1920
)
2021

2122
// Client is an HTTP client for the cagent server API
@@ -373,17 +374,7 @@ func (c *Client) runAgentWithAgentName(ctx context.Context, sessionID, agent, ag
373374
return eventChan, nil
374375
}
375376

376-
func (c *Client) ResumeStartAuthorizationFlow(ctx context.Context, id string, confirmation bool) error {
377-
req := api.ResumeStartOauthRequest{Confirmation: confirmation}
378-
return c.doRequest(ctx, http.MethodPost, "/api/"+id+"/resumeStartOauth", req, nil)
379-
}
380-
381-
func (c *Client) ResumeCodeReceived(ctx context.Context, code, state string) error {
382-
req := api.ResumeCodeReceivedOauthRequest{Code: code, State: state}
383-
return c.doRequest(ctx, http.MethodPost, "/api/resumeCodeReceivedOauth", req, nil)
384-
}
385-
386-
func (c *Client) ResumeElicitation(ctx context.Context, action string, content map[string]any) error {
387-
req := api.ResumeElicitationRequest{Action: action, Content: content}
388-
return c.doRequest(ctx, http.MethodPost, "/api/resumeElicitation", req, nil)
377+
func (c *Client) ResumeElicitation(ctx context.Context, sessionID string, action tools.ElicitationAction, content map[string]any) error {
378+
req := api.ResumeElicitationRequest{Action: string(action), Content: content}
379+
return c.doRequest(ctx, http.MethodPost, "/api/sessions/"+sessionID+"/elicitation", req, nil)
389380
}

pkg/runtime/event.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -293,14 +293,14 @@ func ElicitationRequest(message string, schema any, meta map[string]any, agentNa
293293
func (e *ElicitationRequestEvent) GetAgentName() string { return e.AgentName }
294294

295295
type AuthorizationEvent struct {
296-
Type string `json:"type"`
297-
Confirmation string `json:"confirmation"` // only "confirmed"
296+
Type string `json:"type"`
297+
Confirmation tools.ElicitationAction `json:"confirmation"`
298298
AgentContext
299299
}
300300

301301
func (e *AuthorizationEvent) GetAgentName() string { return "" }
302302

303-
func Authorization(confirmation, agentName string) Event {
303+
func Authorization(confirmation tools.ElicitationAction, agentName string) Event {
304304
return &AuthorizationEvent{
305305
Type: "authorization_event",
306306
Confirmation: confirmation,

pkg/runtime/remote_runtime.go

Lines changed: 177 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,30 @@ package runtime
22

33
import (
44
"context"
5+
"encoding/json"
56
"fmt"
67
"log/slog"
8+
"time"
9+
10+
"golang.org/x/oauth2"
711

812
"github.com/docker/cagent/pkg/api"
913
"github.com/docker/cagent/pkg/chat"
1014
latest "github.com/docker/cagent/pkg/config/v2"
1115
"github.com/docker/cagent/pkg/session"
1216
"github.com/docker/cagent/pkg/team"
17+
"github.com/docker/cagent/pkg/tools"
18+
"github.com/docker/cagent/pkg/tools/mcp"
1319
)
1420

1521
// RemoteRuntime implements the Interface using a remote client
1622
type RemoteRuntime struct {
17-
client *Client
18-
currentAgent string
19-
agentFilename string
20-
sessionID string
21-
team *team.Team
23+
client *Client
24+
currentAgent string
25+
agentFilename string
26+
sessionID string
27+
team *team.Team
28+
pendingOAuthElicitation *ElicitationRequestEvent
2229
}
2330

2431
// RemoteRuntimeOption is a function for configuring the RemoteRuntime
@@ -115,6 +122,10 @@ func (r *RemoteRuntime) RunStream(ctx context.Context, sess *session.Session) <-
115122
}
116123

117124
for streamEvent := range streamChan {
125+
if elicitationRequest, ok := streamEvent.(*ElicitationRequestEvent); ok {
126+
// Store pending OAuth elicitation request
127+
r.pendingOAuthElicitation = elicitationRequest
128+
}
118129
events <- streamEvent
119130
}
120131
}()
@@ -176,51 +187,186 @@ func (r *RemoteRuntime) convertSessionMessages(sess *session.Session) []api.Mess
176187
return messages
177188
}
178189

179-
// ResumeStartAuthorizationFlow allows resuming execution after user confirmation
180-
func (r *RemoteRuntime) ResumeStartAuthorizationFlow(ctx context.Context, confirmationType bool) {
181-
slog.Debug("Resuming remote runtime", "agent", r.currentAgent, "confirmation_type", confirmationType, "session_id", r.sessionID)
190+
// ResumeElicitation sends an elicitation response back to a waiting elicitation request
191+
func (r *RemoteRuntime) ResumeElicitation(ctx context.Context, action tools.ElicitationAction, content map[string]any) error {
192+
slog.Debug("Resuming remote runtime with elicitation response", "agent", r.currentAgent, "action", action, "session_id", r.sessionID)
182193

183-
if r.sessionID == "" {
184-
slog.Error("Cannot resume: no session ID available")
185-
return
194+
err := r.handleOAuthElicitation(ctx, r.pendingOAuthElicitation)
195+
if err != nil {
196+
return err
186197
}
198+
// TODO: once we get here and the elicitation is the OAuth type, we need to start the managed OAuth flow
187199

188-
if err := r.client.ResumeStartAuthorizationFlow(ctx, r.sessionID, confirmationType); err != nil {
189-
slog.Error("Failed to resume remote session", "error", err, "session_id", r.sessionID)
200+
if err := r.client.ResumeElicitation(ctx, r.sessionID, action, content); err != nil {
201+
return err
190202
}
203+
204+
return nil
191205
}
192206

193-
// ResumeCodeReceived allows resuming execution after user confirmation
194-
func (r *RemoteRuntime) ResumeCodeReceived(ctx context.Context, code, state string) error {
195-
slog.Debug("Resuming remote runtime", "agent", r.currentAgent, "code", code, "state", state, "session_id", r.sessionID)
207+
// HandleOAuthElicitation handles OAuth elicitation requests from remote MCP servers
208+
func (r *RemoteRuntime) handleOAuthElicitation(ctx context.Context, req *ElicitationRequestEvent) error {
209+
slog.Debug("Handling OAuth elicitation request", "server_url", req.Meta["cagent/server_url"])
196210

197-
if r.sessionID == "" {
198-
slog.Error("Cannot resume: no session ID available")
199-
return fmt.Errorf("session ID cannot be empty")
211+
// Extract OAuth parameters from metadata
212+
serverURL, ok := req.Meta["cagent/server_url"].(string)
213+
if !ok {
214+
err := fmt.Errorf("server_url missing from elicitation metadata")
215+
slog.Error("Failed to extract server_url", "error", err)
216+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
217+
return err
200218
}
201219

202-
if err := r.client.ResumeCodeReceived(ctx, code, state); err != nil {
203-
slog.Error("Failed to resume remote session", "error", err, "session_id", r.sessionID)
220+
// Extract authorization server metadata
221+
authServerMetadata, ok := req.Meta["auth_server_metadata"].(map[string]any)
222+
if !ok {
223+
err := fmt.Errorf("auth_server_metadata missing from elicitation metadata")
224+
slog.Error("Failed to extract auth_server_metadata", "error", err)
225+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
204226
return err
205227
}
206228

207-
return nil
208-
}
229+
// Unmarshal authorization server metadata
230+
var authMetadata mcp.AuthorizationServerMetadata
231+
metadataBytes, err := json.Marshal(authServerMetadata)
232+
if err != nil {
233+
slog.Error("Failed to marshal auth_server_metadata", "error", err)
234+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
235+
return fmt.Errorf("failed to marshal auth_server_metadata: %w", err)
236+
}
237+
if err := json.Unmarshal(metadataBytes, &authMetadata); err != nil {
238+
slog.Error("Failed to unmarshal auth_server_metadata", "error", err)
239+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
240+
return fmt.Errorf("failed to unmarshal auth_server_metadata: %w", err)
241+
}
209242

210-
// ResumeElicitation sends an elicitation response back to a waiting elicitation request
211-
func (r *RemoteRuntime) ResumeElicitation(ctx context.Context, action string, content map[string]any) error {
212-
slog.Debug("Resuming remote runtime with elicitation response", "agent", r.currentAgent, "action", action, "session_id", r.sessionID)
243+
slog.Debug("Authorization server metadata extracted", "issuer", authMetadata.Issuer)
213244

214-
if r.sessionID == "" {
215-
slog.Error("Cannot resume: no session ID available")
216-
return fmt.Errorf("session ID cannot be empty")
245+
// Create timeout context for OAuth flow (5 minutes)
246+
oauthCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
247+
defer cancel()
248+
249+
// Create and start callback server
250+
slog.Debug("Creating OAuth callback server")
251+
callbackServer, err := mcp.NewCallbackServer()
252+
if err != nil {
253+
slog.Error("Failed to create callback server", "error", err)
254+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
255+
return fmt.Errorf("failed to create callback server: %w", err)
217256
}
257+
defer func() {
258+
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
259+
defer shutdownCancel()
260+
if err := callbackServer.Shutdown(shutdownCtx); err != nil {
261+
slog.Error("Failed to shutdown callback server", "error", err)
262+
}
263+
}()
264+
265+
if err := callbackServer.Start(); err != nil {
266+
slog.Error("Failed to start callback server", "error", err)
267+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
268+
return fmt.Errorf("failed to start callback server: %w", err)
269+
}
270+
271+
redirectURI := callbackServer.GetRedirectURI()
272+
slog.Debug("Callback server started", "redirect_uri", redirectURI)
273+
274+
// Register client
275+
var clientID, clientSecret string
276+
if authMetadata.RegistrationEndpoint != "" {
277+
slog.Debug("Attempting dynamic client registration")
278+
clientID, clientSecret, err = mcp.RegisterClient(oauthCtx, &authMetadata, redirectURI, nil)
279+
if err != nil {
280+
slog.Error("Dynamic client registration failed", "error", err)
281+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
282+
return fmt.Errorf("failed to register client: %w", err)
283+
}
284+
slog.Debug("Client registered successfully", "client_id", clientID)
285+
} else {
286+
err := fmt.Errorf("authorization server does not support dynamic client registration")
287+
slog.Error("Client registration not supported", "error", err)
288+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
289+
return err
290+
}
291+
292+
// Generate state and PKCE verifier
293+
state, err := mcp.GenerateState()
294+
if err != nil {
295+
slog.Error("Failed to generate state", "error", err)
296+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
297+
return fmt.Errorf("failed to generate state: %w", err)
298+
}
299+
300+
callbackServer.SetExpectedState(state)
301+
verifier := mcp.GeneratePKCEVerifier()
302+
303+
// Build authorization URL
304+
authURL := mcp.BuildAuthorizationURL(
305+
authMetadata.AuthorizationEndpoint,
306+
clientID,
307+
redirectURI,
308+
state,
309+
oauth2.S256ChallengeFromVerifier(verifier),
310+
serverURL,
311+
)
218312

219-
if err := r.client.ResumeElicitation(ctx, action, content); err != nil {
220-
slog.Error("Failed to resume remote session with elicitation", "error", err, "session_id", r.sessionID)
313+
slog.Debug("Authorization URL built", "url", authURL)
314+
315+
// Request authorization code (this opens the browser)
316+
slog.Debug("Requesting authorization code")
317+
code, receivedState, err := mcp.RequestAuthorizationCode(oauthCtx, authURL, callbackServer, state)
318+
if err != nil {
319+
slog.Error("Failed to get authorization code", "error", err)
320+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
321+
return fmt.Errorf("failed to get authorization code: %w", err)
322+
}
323+
324+
if receivedState != state {
325+
err := fmt.Errorf("state mismatch: expected %s, got %s", state, receivedState)
326+
slog.Error("State mismatch in authorization response", "error", err)
327+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
221328
return err
222329
}
223330

331+
slog.Debug("Authorization code received, exchanging for token")
332+
333+
// Exchange code for token
334+
token, err := mcp.ExchangeCodeForToken(
335+
oauthCtx,
336+
authMetadata.TokenEndpoint,
337+
code,
338+
verifier,
339+
clientID,
340+
clientSecret,
341+
redirectURI,
342+
)
343+
if err != nil {
344+
slog.Error("Failed to exchange code for token", "error", err)
345+
_ = r.client.ResumeElicitation(ctx, r.sessionID, "decline", nil)
346+
return fmt.Errorf("failed to exchange code for token: %w", err)
347+
}
348+
349+
slog.Debug("Token obtained successfully", "token_type", token.TokenType)
350+
351+
// Send token back to server via ResumeElicitation
352+
tokenData := map[string]any{
353+
"access_token": token.AccessToken,
354+
"token_type": token.TokenType,
355+
}
356+
if token.ExpiresIn > 0 {
357+
tokenData["expires_in"] = token.ExpiresIn
358+
}
359+
if token.RefreshToken != "" {
360+
tokenData["refresh_token"] = token.RefreshToken
361+
}
362+
363+
slog.Debug("Sending token to server")
364+
if err := r.client.ResumeElicitation(ctx, r.sessionID, tools.ElicitationActionAccept, tokenData); err != nil {
365+
slog.Error("Failed to send token to server", "error", err)
366+
return fmt.Errorf("failed to send token to server: %w", err)
367+
}
368+
369+
slog.Debug("OAuth flow completed successfully")
224370
return nil
225371
}
226372

0 commit comments

Comments
 (0)