Skip to content

Commit a5b37a3

Browse files
authored
fix(security): prevent Windows command injection and improve RFC 7009 compliance (#30)
* fix(security): prevent Windows command injection and improve RFC 7009 compliance - Use rundll32 instead of cmd /c start to prevent shell metacharacter injection in URLs on Windows - Replace panic with fmt.Fprintf + os.Exit(1) for consistent error handling in loadConfig - Write token save warning to stderr instead of stdout in refreshAccessToken - Add token_type_hint parameter to revocation requests per RFC 7009 * fix(cli): address Copilot review feedback - Align error message wording to say "retry HTTP client" in loadConfig - Add token_type_hint assertions to revocation tests for RFC 7009 compliance * test(revocation): assert call count in access-only revocation test - Track revocation call count to ensure exactly one request is made when only an access token exists (no refresh token) * fix(cli): address Copilot review round 3 - Conditionally include token_type_hint only when non-empty for server compatibility - Extract browserCommand helper for testability and add unit tests covering darwin, windows, and linux command construction
1 parent f358ea1 commit a5b37a3

6 files changed

Lines changed: 136 additions & 30 deletions

File tree

auth.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"net/http"
88
"net/url"
9+
"os"
910

1011
retry "github.com/appleboy/go-httpretry"
1112
"github.com/go-authgate/cli/tui"
@@ -94,7 +95,7 @@ func refreshAccessToken(
9495
}
9596

9697
if err := cfg.Store.Save(cfg.ClientID, *storage); err != nil {
97-
fmt.Printf("Warning: Failed to save refreshed tokens: %v\n", err)
98+
fmt.Fprintf(os.Stderr, "Warning: Failed to save refreshed tokens: %v\n", err)
9899
}
99100
return storage, nil
100101
}

browser.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,25 @@ import (
77
"runtime"
88
)
99

10-
// openBrowser attempts to open url in the user's default browser.
11-
// Returns an error if launching the browser fails, but callers should
12-
// always print the URL as a fallback regardless of the error.
13-
func openBrowser(ctx context.Context, url string) error {
14-
var cmd *exec.Cmd
15-
16-
switch runtime.GOOS {
10+
// browserCommand returns the executable name and arguments for opening a URL
11+
// on the given OS. This is extracted for testability.
12+
func browserCommand(goos, url string) (name string, args []string) {
13+
switch goos {
1714
case "darwin":
18-
cmd = exec.CommandContext(ctx, "open", url)
15+
return "open", []string{url}
1916
case "windows":
20-
cmd = exec.CommandContext(ctx, "cmd", "/c", "start", url)
17+
return "rundll32", []string{"url.dll,FileProtocolHandler", url}
2118
default:
22-
cmd = exec.CommandContext(ctx, "xdg-open", url)
19+
return "xdg-open", []string{url}
2320
}
21+
}
22+
23+
// openBrowser attempts to open url in the user's default browser.
24+
// Returns an error if launching the browser fails, but callers should
25+
// always print the URL as a fallback regardless of the error.
26+
func openBrowser(ctx context.Context, url string) error {
27+
name, args := browserCommand(runtime.GOOS, url)
28+
cmd := exec.CommandContext(ctx, name, args...)
2429

2530
if err := cmd.Start(); err != nil {
2631
return fmt.Errorf("failed to open browser: %w", err)

browser_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
package main
2+
3+
import "testing"
4+
5+
func TestBrowserCommand(t *testing.T) {
6+
tests := []struct {
7+
goos string
8+
url string
9+
wantName string
10+
wantArgs []string
11+
}{
12+
{
13+
goos: "darwin",
14+
url: "https://example.com/auth",
15+
wantName: "open",
16+
wantArgs: []string{"https://example.com/auth"},
17+
},
18+
{
19+
goos: "windows",
20+
url: "https://example.com/auth",
21+
wantName: "rundll32",
22+
wantArgs: []string{"url.dll,FileProtocolHandler", "https://example.com/auth"},
23+
},
24+
{
25+
goos: "linux",
26+
url: "https://example.com/auth",
27+
wantName: "xdg-open",
28+
wantArgs: []string{"https://example.com/auth"},
29+
},
30+
{
31+
goos: "windows",
32+
url: "https://example.com/auth?foo=bar&baz=qux",
33+
wantName: "rundll32",
34+
wantArgs: []string{
35+
"url.dll,FileProtocolHandler",
36+
"https://example.com/auth?foo=bar&baz=qux",
37+
},
38+
},
39+
}
40+
41+
for _, tt := range tests {
42+
t.Run(tt.goos, func(t *testing.T) {
43+
name, args := browserCommand(tt.goos, tt.url)
44+
if name != tt.wantName {
45+
t.Errorf("name: got %q, want %q", name, tt.wantName)
46+
}
47+
if len(args) != len(tt.wantArgs) {
48+
t.Fatalf("args length: got %d, want %d", len(args), len(tt.wantArgs))
49+
}
50+
for i, arg := range args {
51+
if arg != tt.wantArgs[i] {
52+
t.Errorf("args[%d]: got %q, want %q", i, arg, tt.wantArgs[i])
53+
}
54+
}
55+
})
56+
}
57+
}

config.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ func loadConfig() *AppConfig {
236236
var err error
237237
cfg.RetryClient, err = retry.NewBackgroundClient(retry.WithHTTPClient(baseHTTPClient))
238238
if err != nil {
239-
panic(fmt.Sprintf("failed to create retry client: %v", err))
239+
fmt.Fprintf(os.Stderr, "Error: failed to create retry HTTP client: %v\n", err)
240+
os.Exit(1)
240241
}
241242

242243
// Resolve timeout configuration.

token_cmd.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,14 @@ func revokeTokenOnServer(
172172

173173
if tok.RefreshToken != "" {
174174
wg.Go(func() {
175-
if err := doRevoke(ctx, cfg, revokeURL, tok.RefreshToken, timeout); err != nil {
175+
if err := doRevoke(
176+
ctx,
177+
cfg,
178+
revokeURL,
179+
tok.RefreshToken,
180+
"refresh_token",
181+
timeout,
182+
); err != nil {
176183
mu.Lock()
177184
refreshErr = err
178185
mu.Unlock()
@@ -182,7 +189,14 @@ func revokeTokenOnServer(
182189

183190
if tok.AccessToken != "" {
184191
wg.Go(func() {
185-
if err := doRevoke(ctx, cfg, revokeURL, tok.AccessToken, timeout); err != nil {
192+
if err := doRevoke(
193+
ctx,
194+
cfg,
195+
revokeURL,
196+
tok.AccessToken,
197+
"access_token",
198+
timeout,
199+
); err != nil {
186200
mu.Lock()
187201
accessErr = err
188202
mu.Unlock()
@@ -213,6 +227,7 @@ func doRevoke(
213227
cfg *AppConfig,
214228
revokeURL string,
215229
token string,
230+
tokenTypeHint string,
216231
timeout time.Duration,
217232
) error {
218233
ctx, cancel := context.WithTimeout(ctx, timeout)
@@ -222,6 +237,9 @@ func doRevoke(
222237
"token": {token},
223238
"client_id": {cfg.ClientID},
224239
}
240+
if tokenTypeHint != "" {
241+
data.Set("token_type_hint", tokenTypeHint)
242+
}
225243
if !cfg.IsPublicClient() {
226244
data.Set("client_secret", cfg.ClientSecret)
227245
}

token_cmd_test.go

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ func TestRunTokenDelete(t *testing.T) {
106106

107107
func TestRunTokenDelete_ServerRevocation(t *testing.T) {
108108
t.Run("successful revocation and local delete", func(t *testing.T) {
109-
var revokedTokens []string
109+
type revokeCall struct {
110+
token string
111+
tokenTypeHint string
112+
}
113+
var revokeCalls []revokeCall
110114
var mu sync.Mutex
111115
srv := httptest.NewServer(
112116
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -115,7 +119,10 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) {
115119
return
116120
}
117121
mu.Lock()
118-
revokedTokens = append(revokedTokens, r.FormValue("token"))
122+
revokeCalls = append(revokeCalls, revokeCall{
123+
token: r.FormValue("token"),
124+
tokenTypeHint: r.FormValue("token_type_hint"),
125+
})
119126
mu.Unlock()
120127
w.WriteHeader(http.StatusOK)
121128
}),
@@ -163,16 +170,24 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) {
163170

164171
mu.Lock()
165172
defer mu.Unlock()
166-
if len(revokedTokens) != 2 {
167-
t.Fatalf("expected 2 revoke calls, got %d", len(revokedTokens))
173+
if len(revokeCalls) != 2 {
174+
t.Fatalf("expected 2 revoke calls, got %d", len(revokeCalls))
168175
}
169176
// Revocations run concurrently, so order is non-deterministic.
170-
got := map[string]bool{revokedTokens[0]: true, revokedTokens[1]: true}
171-
if !got["refresh-456"] {
172-
t.Errorf("expected refresh token to be revoked, got %v", revokedTokens)
177+
// Build a map from token to its type hint for assertion.
178+
hintByToken := make(map[string]string, len(revokeCalls))
179+
for _, c := range revokeCalls {
180+
hintByToken[c.token] = c.tokenTypeHint
181+
}
182+
if hint, ok := hintByToken["refresh-456"]; !ok {
183+
t.Errorf("expected refresh token to be revoked, got %v", revokeCalls)
184+
} else if hint != "refresh_token" {
185+
t.Errorf("refresh token_type_hint: got %q, want %q", hint, "refresh_token")
173186
}
174-
if !got["access-123"] {
175-
t.Errorf("expected access token to be revoked, got %v", revokedTokens)
187+
if hint, ok := hintByToken["access-123"]; !ok {
188+
t.Errorf("expected access token to be revoked, got %v", revokeCalls)
189+
} else if hint != "access_token" {
190+
t.Errorf("access token_type_hint: got %q, want %q", hint, "access_token")
176191
}
177192
})
178193

@@ -269,16 +284,22 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) {
269284
})
270285

271286
t.Run("only access token no refresh token", func(t *testing.T) {
272-
var revokedTokens []string
273-
var mu sync.Mutex
287+
var (
288+
callCount int
289+
gotToken string
290+
gotTokenTypeHint string
291+
mu sync.Mutex
292+
)
274293
srv := httptest.NewServer(
275294
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
276295
if err := r.ParseForm(); err != nil {
277296
http.Error(w, "bad form", http.StatusBadRequest)
278297
return
279298
}
280299
mu.Lock()
281-
revokedTokens = append(revokedTokens, r.FormValue("token"))
300+
callCount++
301+
gotToken = r.FormValue("token")
302+
gotTokenTypeHint = r.FormValue("token_type_hint")
282303
mu.Unlock()
283304
w.WriteHeader(http.StatusOK)
284305
}),
@@ -319,11 +340,14 @@ func TestRunTokenDelete_ServerRevocation(t *testing.T) {
319340

320341
mu.Lock()
321342
defer mu.Unlock()
322-
if len(revokedTokens) != 1 {
323-
t.Fatalf("expected 1 revoke call (access only), got %d", len(revokedTokens))
343+
if callCount != 1 {
344+
t.Fatalf("expected 1 revoke call (access only), got %d", callCount)
345+
}
346+
if gotToken != "access-only" {
347+
t.Errorf("token: got %q, want %q", gotToken, "access-only")
324348
}
325-
if revokedTokens[0] != "access-only" {
326-
t.Errorf("expected access token, got %q", revokedTokens[0])
349+
if gotTokenTypeHint != "access_token" {
350+
t.Errorf("token_type_hint: got %q, want %q", gotTokenTypeHint, "access_token")
327351
}
328352
})
329353
}

0 commit comments

Comments
 (0)