From fc205bc018d8b1ca37474725cb2ce15f3f2870bf Mon Sep 17 00:00:00 2001 From: Daniyal Khan Date: Sat, 18 Apr 2026 11:48:24 -0700 Subject: [PATCH 1/3] feat(auth): add --tunnel flag for OAuth login on headless/remote machines Adds `notion-cli auth login --tunnel` which uses localtunnel to expose the OAuth callback server via a public HTTPS URL. This enables OAuth login from headless machines or remote servers that don't have a local browser available. Implements a pure Go localtunnel client with no new dependencies. Co-Authored-By: Claude Opus 4.6 (1M context) --- cmd/auth.go | 9 +- internal/mcp/oauth.go | 38 +++++- internal/tunnel/tunnel.go | 183 +++++++++++++++++++++++++ internal/tunnel/tunnel_test.go | 235 +++++++++++++++++++++++++++++++++ 4 files changed, 460 insertions(+), 5 deletions(-) create mode 100644 internal/tunnel/tunnel.go create mode 100644 internal/tunnel/tunnel_test.go diff --git a/cmd/auth.go b/cmd/auth.go index 4d1e74e..0b49502 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -34,7 +34,9 @@ var notionAPITokenPattern = regexp.MustCompile(`^ntn_[A-Za-z0-9]{20,}$`) const officialAPIIntegrationsURL = "https://www.notion.so/profile/integrations/internal" -type AuthLoginCmd struct{} +type AuthLoginCmd struct { + Tunnel bool `help:"Use a tunnel for the OAuth callback, allowing login from a remote machine without a local browser" default:"false"` +} func (c *AuthLoginCmd) Run(ctx *Context) error { tokenStore, err := mcp.NewFileTokenStore() @@ -44,7 +46,10 @@ func (c *AuthLoginCmd) Run(ctx *Context) error { } bgCtx := context.Background() - if err := mcp.RunOAuthFlow(bgCtx, tokenStore); err != nil { + opts := &mcp.OAuthFlowOptions{ + Tunnel: c.Tunnel, + } + if err := mcp.RunOAuthFlow(bgCtx, tokenStore, opts); err != nil { output.PrintError(err) return err } diff --git a/internal/mcp/oauth.go b/internal/mcp/oauth.go index d47a6ee..e95d7ea 100644 --- a/internal/mcp/oauth.go +++ b/internal/mcp/oauth.go @@ -13,6 +13,7 @@ import ( "runtime" "time" + "github.com/lox/notion-cli/internal/tunnel" "github.com/mark3labs/mcp-go/client" "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" @@ -51,7 +52,18 @@ type OAuthResult struct { Error string } -func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore) error { +// OAuthFlowOptions configures the OAuth login flow. +type OAuthFlowOptions struct { + // Tunnel starts a localtunnel to expose the local callback server, + // allowing authentication from a machine without a local browser. + Tunnel bool +} + +func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore, opts *OAuthFlowOptions) error { + if opts == nil { + opts = &OAuthFlowOptions{} + } + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return fmt.Errorf("start callback server: %w", err) @@ -61,6 +73,17 @@ func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore) error { port := listener.Addr().(*net.TCPAddr).Port redirectURI := fmt.Sprintf("http://localhost:%d%s", port, callbackPath) + if opts.Tunnel { + fmt.Println("Starting tunnel...") + tun, err := tunnel.Start(ctx, port) + if err != nil { + return fmt.Errorf("start tunnel: %w", err) + } + defer tun.Close() + redirectURI = tun.URL + callbackPath + fmt.Printf("Tunnel active: %s\n", tun.URL) + } + oauthConfig := transport.OAuthConfig{ RedirectURI: redirectURI, TokenStore: tokenStore, @@ -166,10 +189,19 @@ func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore) error { fmt.Println() fmt.Printf(" %s\n", authURL) fmt.Println() + + if opts.Tunnel { + fmt.Println("NOTE: After authenticating, you may see a tunnel interstitial page.") + fmt.Println("Click \"Click to Continue\" to complete the callback.") + fmt.Println() + } + fmt.Println("Waiting for authentication...") - if err := OpenBrowser(authURL); err != nil { - fmt.Printf("(Could not open browser automatically: %v)\n", err) + if !opts.Tunnel { + if err := OpenBrowser(authURL); err != nil { + fmt.Printf("(Could not open browser automatically: %v)\n", err) + } } select { diff --git a/internal/tunnel/tunnel.go b/internal/tunnel/tunnel.go new file mode 100644 index 0000000..acff437 --- /dev/null +++ b/internal/tunnel/tunnel.go @@ -0,0 +1,183 @@ +package tunnel + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "sync" + "time" +) + +// DefaultServer is the public localtunnel.me instance. +const DefaultServer = "https://localtunnel.me" + +// Tunnel proxies connections from a public URL to a local port using the +// localtunnel protocol (https://github.com/localtunnel/server). +type Tunnel struct { + // URL is the public HTTPS URL assigned by the tunnel server. + URL string + + localPort int + remoteHost string + remotePort int + maxConn int + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +type assignment struct { + ID string `json:"id"` + Port int `json:"port"` + URL string `json:"url"` + MaxConnCount int `json:"max_conn_count"` +} + +// Start opens a tunnel from the default localtunnel.me server to localPort. +func Start(ctx context.Context, localPort int) (*Tunnel, error) { + return StartWithServer(ctx, localPort, DefaultServer) +} + +// StartWithServer opens a tunnel using a localtunnel-compatible server. +func StartWithServer(ctx context.Context, localPort int, serverURL string) (*Tunnel, error) { + parsed, err := url.Parse(serverURL) + if err != nil { + return nil, fmt.Errorf("parse server URL: %w", err) + } + + info, err := requestTunnel(ctx, serverURL) + if err != nil { + return nil, err + } + + if info.URL == "" || info.Port == 0 { + return nil, fmt.Errorf("invalid tunnel assignment: missing URL or port") + } + + tctx, cancel := context.WithCancel(ctx) + + t := &Tunnel{ + URL: info.URL, + localPort: localPort, + remoteHost: parsed.Hostname(), + remotePort: info.Port, + maxConn: info.MaxConnCount, + ctx: tctx, + cancel: cancel, + } + + if t.maxConn <= 0 { + t.maxConn = 10 + } + + for i := 0; i < t.maxConn; i++ { + t.wg.Add(1) + go t.worker() + } + + return t, nil +} + +func requestTunnel(ctx context.Context, serverURL string) (*assignment, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, serverURL+"/?new", nil) + if err != nil { + return nil, fmt.Errorf("build tunnel request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("request tunnel: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 512)) + return nil, fmt.Errorf("tunnel server returned %d: %s", resp.StatusCode, string(body)) + } + + var info assignment + if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { + return nil, fmt.Errorf("decode tunnel assignment: %w", err) + } + + return &info, nil +} + +func (t *Tunnel) worker() { + defer t.wg.Done() + for { + select { + case <-t.ctx.Done(): + return + default: + if err := t.proxy(); err != nil { + // Brief pause before reconnecting on error. + select { + case <-t.ctx.Done(): + return + case <-time.After(time.Second): + } + } + } + } +} + +func (t *Tunnel) proxy() error { + d := net.Dialer{Timeout: 10 * time.Second} + remote, err := d.DialContext(t.ctx, "tcp", fmt.Sprintf("%s:%d", t.remoteHost, t.remotePort)) + if err != nil { + return err + } + + // Ensure the remote connection is closed when the context is cancelled, + // which unblocks the blocking ReadFull below. + proxyDone := make(chan struct{}) + defer close(proxyDone) + go func() { + select { + case <-t.ctx.Done(): + remote.Close() + case <-proxyDone: + } + }() + defer remote.Close() + + // Block until the tunnel server forwards a request to this connection. + header := make([]byte, 1) + if _, err := io.ReadFull(remote, header); err != nil { + return err + } + + local, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", t.localPort), 5*time.Second) + if err != nil { + return err + } + defer local.Close() + + // Forward the byte we already consumed. + if _, err := local.Write(header); err != nil { + return err + } + + // Bidirectional copy. + errc := make(chan error, 2) + go func() { _, err := io.Copy(local, remote); errc <- err }() + go func() { _, err := io.Copy(remote, local); errc <- err }() + + select { + case <-errc: + case <-t.ctx.Done(): + } + + return nil +} + +// Close shuts down the tunnel and waits for all proxy workers to exit. +func (t *Tunnel) Close() { + t.cancel() + t.wg.Wait() +} diff --git a/internal/tunnel/tunnel_test.go b/internal/tunnel/tunnel_test.go new file mode 100644 index 0000000..4cc6386 --- /dev/null +++ b/internal/tunnel/tunnel_test.go @@ -0,0 +1,235 @@ +package tunnel + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" +) + +// mockTunnelServer simulates a localtunnel-compatible server. +// It accepts tunnel assignment requests and forwards connections between the +// public side and the pooled client connections. +type mockTunnelServer struct { + // proxyListener is where the tunnel client opens pooled connections. + proxyListener net.Listener + // httpServer is the assignment endpoint. + httpServer *httptest.Server + + mu sync.Mutex + pool []net.Conn + poolCond *sync.Cond +} + +func newMockTunnelServer(t *testing.T) *mockTunnelServer { + t.Helper() + + proxyListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen proxy: %v", err) + } + + m := &mockTunnelServer{ + proxyListener: proxyListener, + } + m.poolCond = sync.NewCond(&m.mu) + + proxyPort := proxyListener.Addr().(*net.TCPAddr).Port + + m.httpServer = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + info := assignment{ + ID: "test-tunnel", + Port: proxyPort, + URL: fmt.Sprintf("http://127.0.0.1:%d", proxyPort), + MaxConnCount: 2, + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(info) + })) + + // Accept pooled connections from the tunnel client. + go func() { + for { + conn, err := proxyListener.Accept() + if err != nil { + return + } + m.mu.Lock() + m.pool = append(m.pool, conn) + m.poolCond.Broadcast() + m.mu.Unlock() + } + }() + + t.Cleanup(func() { + m.httpServer.Close() + proxyListener.Close() + }) + + return m +} + +// getPooledConn blocks until a pooled connection is available and returns it. +func (m *mockTunnelServer) getPooledConn(timeout time.Duration) (net.Conn, bool) { + deadline := time.After(timeout) + done := make(chan struct{}) + + var conn net.Conn + + go func() { + m.mu.Lock() + defer m.mu.Unlock() + for len(m.pool) == 0 { + m.poolCond.Wait() + } + conn = m.pool[0] + m.pool = m.pool[1:] + close(done) + }() + + select { + case <-done: + return conn, true + case <-deadline: + return nil, false + } +} + +func TestTunnelProxiesHTTPRequest(t *testing.T) { + // Start a local HTTP server (the "application" behind the tunnel). + localListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen local: %v", err) + } + localPort := localListener.Addr().(*net.TCPAddr).Port + + localServer := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "hello from local") + }), + } + go localServer.Serve(localListener) + defer localServer.Close() + + // Set up mock tunnel server. + mock := newMockTunnelServer(t) + + // Start the tunnel client. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + tun, err := StartWithServer(ctx, localPort, mock.httpServer.URL) + if err != nil { + t.Fatalf("StartWithServer: %v", err) + } + defer tun.Close() + + if tun.URL == "" { + t.Fatal("tunnel URL is empty") + } + + // Wait for a pooled connection to appear. + pooledConn, ok := mock.getPooledConn(5 * time.Second) + if !ok { + t.Fatal("no pooled connection from tunnel client") + } + + // Simulate a request arriving at the tunnel: write an HTTP request to the + // pooled connection and read back the response. + reqStr := "GET /callback?code=abc&state=xyz HTTP/1.1\r\nHost: test\r\nConnection: close\r\n\r\n" + if _, err := pooledConn.Write([]byte(reqStr)); err != nil { + t.Fatalf("write to pooled conn: %v", err) + } + + pooledConn.SetReadDeadline(time.Now().Add(5 * time.Second)) + respBytes, err := io.ReadAll(pooledConn) + if err != nil { + t.Fatalf("read response: %v", err) + } + + resp := string(respBytes) + if !strings.Contains(resp, "hello from local") { + t.Fatalf("unexpected response: %s", resp) + } +} + +func TestTunnelCloseStopsWorkers(t *testing.T) { + // Local listener that we never actually serve, just need the port. + localListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + localPort := localListener.Addr().(*net.TCPAddr).Port + defer localListener.Close() + + mock := newMockTunnelServer(t) + + ctx := context.Background() + tun, err := StartWithServer(ctx, localPort, mock.httpServer.URL) + if err != nil { + t.Fatalf("StartWithServer: %v", err) + } + + // Close should return promptly without hanging. + done := make(chan struct{}) + go func() { + tun.Close() + close(done) + }() + + select { + case <-done: + // ok + case <-time.After(5 * time.Second): + t.Fatal("Close() did not return within timeout") + } +} + +func TestRequestTunnelBadStatus(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "overloaded", http.StatusServiceUnavailable) + })) + defer srv.Close() + + _, err := requestTunnel(context.Background(), srv.URL) + if err == nil { + t.Fatal("expected error for non-200 status") + } + if !strings.Contains(err.Error(), "503") { + t.Fatalf("error should mention status code: %v", err) + } +} + +func TestRequestTunnelBadJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "not json") + })) + defer srv.Close() + + _, err := requestTunnel(context.Background(), srv.URL) + if err == nil { + t.Fatal("expected error for bad JSON") + } +} + +func TestStartWithServerRejectsMissingURL(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + json.NewEncoder(w).Encode(assignment{Port: 1234}) + })) + defer srv.Close() + + _, err := StartWithServer(context.Background(), 8080, srv.URL) + if err == nil { + t.Fatal("expected error for missing URL in assignment") + } + if !strings.Contains(err.Error(), "missing URL or port") { + t.Fatalf("unexpected error: %v", err) + } +} From d188c16a2fd4258e8e0a513d21c744b176e475c0 Mon Sep 17 00:00:00 2001 From: Daniyal Khan Date: Sat, 18 Apr 2026 11:55:57 -0700 Subject: [PATCH 2/3] fix(auth): use unguessable callback path in tunnel mode When the callback server is exposed via a public tunnel URL, the fixed /callback path is guessable. Generate a random per-attempt nonce so the callback path becomes /callback/, preventing unsolicited requests from consuming the callback before the real OAuth redirect arrives. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/mcp/oauth.go | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/internal/mcp/oauth.go b/internal/mcp/oauth.go index e95d7ea..405db7d 100644 --- a/internal/mcp/oauth.go +++ b/internal/mcp/oauth.go @@ -71,7 +71,20 @@ func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore, opts *OAuthFl defer func() { _ = listener.Close() }() port := listener.Addr().(*net.TCPAddr).Port - redirectURI := fmt.Sprintf("http://localhost:%d%s", port, callbackPath) + + // When tunneled, the callback is publicly reachable, so use an + // unguessable per-attempt path to prevent unsolicited requests from + // consuming the callback before the real OAuth redirect arrives. + cbPath := callbackPath + if opts.Tunnel { + nonce, err := GenerateState() // reuse the same random generator + if err != nil { + return fmt.Errorf("generate callback nonce: %w", err) + } + cbPath = callbackPath + "/" + nonce + } + + redirectURI := fmt.Sprintf("http://localhost:%d%s", port, cbPath) if opts.Tunnel { fmt.Println("Starting tunnel...") @@ -80,7 +93,7 @@ func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore, opts *OAuthFl return fmt.Errorf("start tunnel: %w", err) } defer tun.Close() - redirectURI = tun.URL + callbackPath + redirectURI = tun.URL + cbPath fmt.Printf("Tunnel active: %s\n", tun.URL) } @@ -152,7 +165,7 @@ func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore, opts *OAuthFl server := &http.Server{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != callbackPath { + if r.URL.Path != cbPath { http.NotFound(w, r) return } From a28e9d5049d410cee7f72701bcaf9fafd90e35c8 Mon Sep 17 00:00:00 2001 From: Daniyal Khan Date: Mon, 20 Apr 2026 21:17:53 -0700 Subject: [PATCH 3/3] fix(auth): check existing auth before starting tunnel Moves the authentication check before tunnel/listener setup so that `auth login --tunnel` doesn't depend on localtunnel availability when the user already has valid credentials. Co-Authored-By: Claude Opus 4.6 (1M context) --- internal/mcp/oauth.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/internal/mcp/oauth.go b/internal/mcp/oauth.go index 405db7d..0493acf 100644 --- a/internal/mcp/oauth.go +++ b/internal/mcp/oauth.go @@ -59,11 +59,39 @@ type OAuthFlowOptions struct { Tunnel bool } +// isAuthenticated attempts a quick MCP Initialize to check whether valid +// credentials already exist in the token store. +func isAuthenticated(ctx context.Context, tokenStore *FileTokenStore) bool { + cfg := transport.OAuthConfig{TokenStore: tokenStore} + t, err := transport.NewStreamableHTTP(DefaultEndpoint, transport.WithHTTPOAuth(cfg)) + if err != nil { + return false + } + c := client.NewClient(t) + defer func() { _ = c.Close() }() + if err := c.Start(ctx); err != nil { + return false + } + initReq := mcp.InitializeRequest{} + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initReq.Params.ClientInfo = mcp.Implementation{Name: "notion-cli", Version: "0.1.0"} + _, err = c.Initialize(ctx, initReq) + return err == nil +} + func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore, opts *OAuthFlowOptions) error { if opts == nil { opts = &OAuthFlowOptions{} } + // Check if already authenticated before starting expensive resources + // (tunnels, listeners). This avoids requiring tunnel availability when + // the user already has valid credentials. + if isAuthenticated(ctx, tokenStore) { + fmt.Println("Already authenticated!") + return nil + } + listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return fmt.Errorf("start callback server: %w", err)