Skip to content

Commit aec6628

Browse files
authored
Merge pull request #346 from rumpl/oauth-only-when-needed
Only start the callback server when it's needed
2 parents 46c8a15 + 2a13cda commit aec6628

4 files changed

Lines changed: 112 additions & 31 deletions

File tree

cmd/root/run.go

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import (
2323
"github.com/docker/cagent/pkg/chat"
2424
"github.com/docker/cagent/pkg/content"
2525
"github.com/docker/cagent/pkg/evaluation"
26-
"github.com/docker/cagent/pkg/oauth"
2726
"github.com/docker/cagent/pkg/remote"
2827
"github.com/docker/cagent/pkg/runtime"
2928
"github.com/docker/cagent/pkg/session"
@@ -181,31 +180,6 @@ func doRunCommand(ctx context.Context, args []string, exec bool) error {
181180
agentFilename = tmpFile.Name()
182181
}
183182

184-
// Set up OAuth redirect URI for CLI/TUI mode
185-
if runConfig.RedirectURI == "" {
186-
runConfig.RedirectURI = "http://localhost:8083/oauth-callback"
187-
slog.Debug("Set default OAuth redirect URI for CLI/TUI mode", "redirectURI", runConfig.RedirectURI)
188-
189-
// Start OAuth callback server for CLI/TUI mode
190-
callbackServer := oauth.NewCallbackServer(8083)
191-
err := callbackServer.Start(ctx)
192-
if err != nil {
193-
slog.Warn("Failed to start OAuth callback server", "error", err)
194-
} else {
195-
defer func() {
196-
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
197-
defer cancel()
198-
if err := callbackServer.Stop(shutdownCtx); err != nil {
199-
slog.Error("Failed to stop OAuth callback server", "error", err)
200-
}
201-
}()
202-
slog.Debug("Started OAuth callback server", "port", 8083)
203-
204-
// Set up global callback server for OAuth manager
205-
oauth.SetGlobalCallbackServer(callbackServer)
206-
}
207-
}
208-
209183
agents, err = teamloader.Load(ctx, agentFilename, runConfig)
210184
if err != nil {
211185
return err

pkg/oauth/interfaces.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ type Manager interface {
1212

1313
// SendAuthorizationCode sends the OAuth authorization code after user has completed the OAuth flow
1414
SendAuthorizationCode(ctx context.Context, code string) error
15+
16+
// Cleanup stops any owned resources like callback servers
17+
Cleanup(ctx context.Context) error
1518
}
1619

1720
// ServerInfoProvider interface for toolsets that can provide server information

pkg/oauth/manager.go

Lines changed: 104 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"log/slog"
7+
"sync"
78
"time"
89

910
"github.com/mark3labs/mcp-go/client"
@@ -15,14 +16,48 @@ type manager struct {
1516
emitAuthRequired func(serverURL, serverType, status string)
1617
resumeAuthorizeOauthFlow chan bool
1718
resumeOauthCodeReceived chan string
19+
callbackServer *CallbackServer
20+
serverMutex sync.Mutex
21+
redirectURI string
22+
port int
1823
}
1924

20-
// NewManager creates a new OAuth manager
21-
func NewManager(emitAuthRequired func(serverURL, serverType, status string)) Manager {
22-
return &manager{
25+
// NewManager creates a new OAuth manager with optional port configuration
26+
func NewManager(emitAuthRequired func(serverURL, serverType, status string), opts ...ManagerOption) Manager {
27+
m := &manager{
2328
emitAuthRequired: emitAuthRequired,
2429
resumeAuthorizeOauthFlow: make(chan bool),
2530
resumeOauthCodeReceived: make(chan string),
31+
port: 8083,
32+
}
33+
34+
// Apply options
35+
for _, opt := range opts {
36+
opt(m)
37+
}
38+
39+
// Set redirect URI based on port
40+
if m.redirectURI == "" {
41+
m.redirectURI = fmt.Sprintf("http://localhost:%d/oauth-callback", m.port)
42+
}
43+
44+
return m
45+
}
46+
47+
// ManagerOption configures the OAuth manager
48+
type ManagerOption func(*manager)
49+
50+
// WithPort sets the callback server port
51+
func WithPort(port int) ManagerOption {
52+
return func(m *manager) {
53+
m.port = port
54+
}
55+
}
56+
57+
// WithRedirectURI sets a custom redirect URI
58+
func WithRedirectURI(uri string) ManagerOption {
59+
return func(m *manager) {
60+
m.redirectURI = uri
2661
}
2762
}
2863

@@ -136,8 +171,13 @@ func (m *manager) performOAuthAuthorization(ctx context.Context, sessionID strin
136171
slog.Debug("Waiting for OAuth authorization code")
137172
var code string
138173

139-
// Check if we have a global callback server running
140-
if callbackServer := GetGlobalCallbackServer(); callbackServer != nil {
174+
// Ensure callback server is started if needed
175+
if err := m.ensureCallbackServer(ctx); err != nil {
176+
slog.Warn("Failed to start callback server, falling back to manual input", "error", err)
177+
}
178+
179+
// Check if we have a callback server running (either global or our own)
180+
if callbackServer := m.getCallbackServer(); callbackServer != nil {
141181
slog.Debug("Using callback server for OAuth authorization")
142182
// Wait for callback from the browser
143183
callbackCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
@@ -188,3 +228,62 @@ func (m *manager) performOAuthAuthorization(ctx context.Context, sessionID strin
188228
slog.Info("OAuth authorization completed successfully")
189229
return nil
190230
}
231+
232+
// ensureCallbackServer starts the callback server if it's not already running
233+
func (m *manager) ensureCallbackServer(ctx context.Context) error {
234+
m.serverMutex.Lock()
235+
defer m.serverMutex.Unlock()
236+
237+
// Check if there's already a global callback server
238+
if globalServer := GetGlobalCallbackServer(); globalServer != nil {
239+
slog.Debug("Using existing global callback server")
240+
return nil
241+
}
242+
243+
// Check if we already have our own server
244+
if m.callbackServer != nil {
245+
slog.Debug("Callback server already started")
246+
return nil
247+
}
248+
249+
// Create and start new callback server
250+
slog.Debug("Starting OAuth callback server on demand", "port", m.port)
251+
m.callbackServer = NewCallbackServer(m.port)
252+
if err := m.callbackServer.Start(ctx); err != nil {
253+
return fmt.Errorf("failed to start OAuth callback server: %w", err)
254+
}
255+
256+
// Set as global server so other components can use it
257+
SetGlobalCallbackServer(m.callbackServer)
258+
259+
slog.Debug("OAuth callback server started successfully", "port", m.port)
260+
return nil
261+
}
262+
263+
// getCallbackServer returns the active callback server (either global or local)
264+
func (m *manager) getCallbackServer() *CallbackServer {
265+
// Prefer global server first
266+
if globalServer := GetGlobalCallbackServer(); globalServer != nil {
267+
return globalServer
268+
}
269+
270+
// Fall back to our local server
271+
return m.callbackServer
272+
}
273+
274+
// Cleanup stops the callback server if we own it
275+
func (m *manager) Cleanup(ctx context.Context) error {
276+
m.serverMutex.Lock()
277+
defer m.serverMutex.Unlock()
278+
279+
if m.callbackServer != nil {
280+
slog.Debug("Stopping OAuth callback server")
281+
if err := m.callbackServer.Stop(ctx); err != nil {
282+
slog.Error("Failed to stop OAuth callback server", "error", err)
283+
return err
284+
}
285+
m.callbackServer = nil
286+
}
287+
288+
return nil
289+
}

pkg/runtime/runtime.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,11 @@ func (r *runtime) handleOAuthAuthorizationFlow(ctx context.Context, sess *sessio
140140
events <- AuthorizationRequired(serverURL, serverType, status)
141141
}
142142
r.oauthManager = oauth.NewManager(emitAuthRequired)
143+
defer func() {
144+
if cleanupErr := r.oauthManager.Cleanup(ctx); cleanupErr != nil {
145+
slog.Error("Failed to cleanup OAuth manager", "error", cleanupErr)
146+
}
147+
}()
143148
}
144149

145150
return r.oauthManager.HandleAuthorizationFlow(ctx, sess.ID, oauthRequiredErr)

0 commit comments

Comments
 (0)