Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions cmd/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ var notionAPITokenPattern = regexp.MustCompile(`^ntn_[A-Za-z0-9]{20,}$`)

const officialAPIIntegrationsURL = "https://www.notion.so/profile/integrations/internal"

type AuthLoginCmd struct{}
type AuthLoginCmd struct {
Tunnel bool `help:"Use a tunnel for the OAuth callback, allowing login from a remote machine without a local browser" default:"false"`
}

func (c *AuthLoginCmd) Run(ctx *Context) error {
tokenStore, err := mcp.NewFileTokenStore()
Expand All @@ -44,7 +46,10 @@ func (c *AuthLoginCmd) Run(ctx *Context) error {
}

bgCtx := context.Background()
if err := mcp.RunOAuthFlow(bgCtx, tokenStore); err != nil {
opts := &mcp.OAuthFlowOptions{
Tunnel: c.Tunnel,
}
if err := mcp.RunOAuthFlow(bgCtx, tokenStore, opts); err != nil {
output.PrintError(err)
return err
}
Expand Down
83 changes: 78 additions & 5 deletions internal/mcp/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"runtime"
"time"

"github.com/lox/notion-cli/internal/tunnel"
"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
Expand Down Expand Up @@ -51,15 +52,78 @@ type OAuthResult struct {
Error string
}

func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore) error {
// OAuthFlowOptions configures the OAuth login flow.
type OAuthFlowOptions struct {
// Tunnel starts a localtunnel to expose the local callback server,
// allowing authentication from a machine without a local browser.
Tunnel bool
}

// isAuthenticated attempts a quick MCP Initialize to check whether valid
// credentials already exist in the token store.
func isAuthenticated(ctx context.Context, tokenStore *FileTokenStore) bool {
cfg := transport.OAuthConfig{TokenStore: tokenStore}
t, err := transport.NewStreamableHTTP(DefaultEndpoint, transport.WithHTTPOAuth(cfg))
if err != nil {
return false
}
c := client.NewClient(t)
defer func() { _ = c.Close() }()
if err := c.Start(ctx); err != nil {
return false
}
initReq := mcp.InitializeRequest{}
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
initReq.Params.ClientInfo = mcp.Implementation{Name: "notion-cli", Version: "0.1.0"}
_, err = c.Initialize(ctx, initReq)
return err == nil
}

func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore, opts *OAuthFlowOptions) error {
if opts == nil {
opts = &OAuthFlowOptions{}
}

// Check if already authenticated before starting expensive resources
// (tunnels, listeners). This avoids requiring tunnel availability when
// the user already has valid credentials.
if isAuthenticated(ctx, tokenStore) {
fmt.Println("Already authenticated!")
return nil
}

listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return fmt.Errorf("start callback server: %w", err)
}
defer func() { _ = listener.Close() }()

port := listener.Addr().(*net.TCPAddr).Port
redirectURI := fmt.Sprintf("http://localhost:%d%s", port, callbackPath)

// When tunneled, the callback is publicly reachable, so use an
// unguessable per-attempt path to prevent unsolicited requests from
// consuming the callback before the real OAuth redirect arrives.
cbPath := callbackPath
if opts.Tunnel {
nonce, err := GenerateState() // reuse the same random generator
if err != nil {
return fmt.Errorf("generate callback nonce: %w", err)
}
cbPath = callbackPath + "/" + nonce
}

redirectURI := fmt.Sprintf("http://localhost:%d%s", port, cbPath)

if opts.Tunnel {
fmt.Println("Starting tunnel...")
tun, err := tunnel.Start(ctx, port)
if err != nil {
return fmt.Errorf("start tunnel: %w", err)
Comment on lines +117 to +121
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Check existing auth before opening the tunnel

Starting the tunnel before mcpClient.Initialize makes auth login --tunnel depend on localtunnel availability even when no OAuth flow is needed. If a user is already authenticated (the Initialize call would succeed), this path now fails early at tunnel.Start(...) on offline/restricted networks instead of returning Already authenticated!, which is a regression introduced by this ordering.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in a28e9d5 — added an isAuthenticated pre-check at the top of RunOAuthFlow that tries a quick Initialize before starting any listener or tunnel. If credentials are already valid, it returns immediately without touching localtunnel.

}
defer tun.Close()
redirectURI = tun.URL + cbPath
fmt.Printf("Tunnel active: %s\n", tun.URL)
}

oauthConfig := transport.OAuthConfig{
RedirectURI: redirectURI,
Expand Down Expand Up @@ -129,7 +193,7 @@ func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore) error {

server := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != callbackPath {
if r.URL.Path != cbPath {
http.NotFound(w, r)
return
}
Expand Down Expand Up @@ -166,10 +230,19 @@ func RunOAuthFlow(ctx context.Context, tokenStore *FileTokenStore) error {
fmt.Println()
fmt.Printf(" %s\n", authURL)
fmt.Println()

if opts.Tunnel {
fmt.Println("NOTE: After authenticating, you may see a tunnel interstitial page.")
fmt.Println("Click \"Click to Continue\" to complete the callback.")
fmt.Println()
}

fmt.Println("Waiting for authentication...")

if err := OpenBrowser(authURL); err != nil {
fmt.Printf("(Could not open browser automatically: %v)\n", err)
if !opts.Tunnel {
if err := OpenBrowser(authURL); err != nil {
fmt.Printf("(Could not open browser automatically: %v)\n", err)
}
}

select {
Expand Down
183 changes: 183 additions & 0 deletions internal/tunnel/tunnel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package tunnel

import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"sync"
"time"
)

// DefaultServer is the public localtunnel.me instance.
const DefaultServer = "https://localtunnel.me"

// Tunnel proxies connections from a public URL to a local port using the
// localtunnel protocol (https://github.com/localtunnel/server).
type Tunnel struct {
// URL is the public HTTPS URL assigned by the tunnel server.
URL string

localPort int
remoteHost string
remotePort int
maxConn int
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}

type assignment struct {
ID string `json:"id"`
Port int `json:"port"`
URL string `json:"url"`
MaxConnCount int `json:"max_conn_count"`
}

// Start opens a tunnel from the default localtunnel.me server to localPort.
func Start(ctx context.Context, localPort int) (*Tunnel, error) {
return StartWithServer(ctx, localPort, DefaultServer)
}

// StartWithServer opens a tunnel using a localtunnel-compatible server.
func StartWithServer(ctx context.Context, localPort int, serverURL string) (*Tunnel, error) {
parsed, err := url.Parse(serverURL)
if err != nil {
return nil, fmt.Errorf("parse server URL: %w", err)
}

info, err := requestTunnel(ctx, serverURL)
if err != nil {
return nil, err
}

if info.URL == "" || info.Port == 0 {
return nil, fmt.Errorf("invalid tunnel assignment: missing URL or port")
}

tctx, cancel := context.WithCancel(ctx)

t := &Tunnel{
URL: info.URL,
localPort: localPort,
remoteHost: parsed.Hostname(),
remotePort: info.Port,
maxConn: info.MaxConnCount,
ctx: tctx,
cancel: cancel,
}

if t.maxConn <= 0 {
t.maxConn = 10
}

for i := 0; i < t.maxConn; i++ {
t.wg.Add(1)
go t.worker()
}

return t, nil
}

func requestTunnel(ctx context.Context, serverURL string) (*assignment, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, serverURL+"/?new", nil)
if err != nil {
return nil, fmt.Errorf("build tunnel request: %w", err)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("request tunnel: %w", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
return nil, fmt.Errorf("tunnel server returned %d: %s", resp.StatusCode, string(body))
}

var info assignment
if err := json.NewDecoder(resp.Body).Decode(&info); err != nil {
return nil, fmt.Errorf("decode tunnel assignment: %w", err)
}

return &info, nil
}

func (t *Tunnel) worker() {
defer t.wg.Done()
for {
select {
case <-t.ctx.Done():
return
default:
if err := t.proxy(); err != nil {
// Brief pause before reconnecting on error.
select {
case <-t.ctx.Done():
return
case <-time.After(time.Second):
}
}
}
}
}

func (t *Tunnel) proxy() error {
d := net.Dialer{Timeout: 10 * time.Second}
remote, err := d.DialContext(t.ctx, "tcp", fmt.Sprintf("%s:%d", t.remoteHost, t.remotePort))
if err != nil {
return err
}

// Ensure the remote connection is closed when the context is cancelled,
// which unblocks the blocking ReadFull below.
proxyDone := make(chan struct{})
defer close(proxyDone)
go func() {
select {
case <-t.ctx.Done():
remote.Close()
case <-proxyDone:
}
}()
defer remote.Close()

// Block until the tunnel server forwards a request to this connection.
header := make([]byte, 1)
if _, err := io.ReadFull(remote, header); err != nil {
return err
}

local, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", t.localPort), 5*time.Second)
if err != nil {
return err
}
defer local.Close()

// Forward the byte we already consumed.
if _, err := local.Write(header); err != nil {
return err
}

// Bidirectional copy.
errc := make(chan error, 2)
go func() { _, err := io.Copy(local, remote); errc <- err }()
go func() { _, err := io.Copy(remote, local); errc <- err }()

select {
case <-errc:
case <-t.ctx.Done():
}

return nil
}

// Close shuts down the tunnel and waits for all proxy workers to exit.
func (t *Tunnel) Close() {
t.cancel()
t.wg.Wait()
}
Loading