Skip to content

Commit e6d3dec

Browse files
authored
refactor: enable context-aware operations and graceful shutdown throughout (#3)
- Refactor functions throughout the codebase to accept and propagate context.Context for improved cancellation support - Add signal-based context initialization for graceful shutdown in main - Move application logic into a run function that returns an exit code for easier control flow - Use net.ListenConfig with context in place of net.Listen to support context-aware socket binding - Update browser open, callback server, token exchange, token verification, and refresh operations to be context-aware - Update all related unit tests to provide context explicitly and use ListenConfig for port binding - Replace plain lock.release calls with error-ignoring variants to avoid unused return value errors - Add error handling for JSON encoding failures in all HTTP test handlers - Introduce a helper function for error condition OAuth device flow tests to reduce duplicate test code - Minor test improvements including commenting, error messages, and use of constants over literals for tokens Signed-off-by: appleboy <appleboy.tw@gmail.com>
1 parent 8149f61 commit e6d3dec

11 files changed

Lines changed: 144 additions & 119 deletions

browser.go

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"fmt"
56
"os/exec"
67
"runtime"
@@ -9,16 +10,16 @@ import (
910
// openBrowser attempts to open url in the user's default browser.
1011
// Returns an error if launching the browser fails, but callers should
1112
// always print the URL as a fallback regardless of the error.
12-
func openBrowser(url string) error {
13+
func openBrowser(ctx context.Context, url string) error {
1314
var cmd *exec.Cmd
1415

1516
switch runtime.GOOS {
1617
case "darwin":
17-
cmd = exec.Command("open", url)
18+
cmd = exec.CommandContext(ctx, "open", url)
1819
case "windows":
19-
cmd = exec.Command("cmd", "/c", "start", url)
20+
cmd = exec.CommandContext(ctx, "cmd", "/c", "start", url)
2021
default:
21-
cmd = exec.Command("xdg-open", url)
22+
cmd = exec.CommandContext(ctx, "xdg-open", url)
2223
}
2324

2425
if err := cmd.Start(); err != nil {

browser_flow.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
// - (storage, true, nil) on success
1919
// - (nil, false, nil) when openBrowser() fails — caller should fall back to Device Code Flow
2020
// - (nil, false, err) on a hard error (CSRF mismatch, token exchange failure, etc.)
21-
func performBrowserFlow() (*TokenStorage, bool, error) {
21+
func performBrowserFlow(ctx context.Context) (*TokenStorage, bool, error) {
2222
state, err := generateState()
2323
if err != nil {
2424
return nil, false, fmt.Errorf("failed to generate state: %w", err)
@@ -34,7 +34,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) {
3434
fmt.Println("Step 1: Opening browser for authorization...")
3535
fmt.Printf("\n %s\n\n", authURL)
3636

37-
if err := openBrowser(authURL); err != nil {
37+
if err := openBrowser(ctx, authURL); err != nil {
3838
// Browser failed to open — signal the caller to fall back immediately.
3939
fmt.Printf("Could not open browser: %v\n", err)
4040
return nil, false, nil
@@ -43,7 +43,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) {
4343
fmt.Println("Browser opened. Please complete authorization in your browser.")
4444
fmt.Printf("Step 2: Waiting for callback on http://localhost:%d/callback ...\n", callbackPort)
4545

46-
code, err := startCallbackServer(callbackPort, state)
46+
code, err := startCallbackServer(ctx, callbackPort, state)
4747
if err != nil {
4848
if errors.Is(err, ErrCallbackTimeout) {
4949
// User opened the browser but didn't complete authorization in time.
@@ -59,7 +59,7 @@ func performBrowserFlow() (*TokenStorage, bool, error) {
5959
fmt.Println("Authorization code received!")
6060

6161
fmt.Println("Step 3: Exchanging authorization code for tokens...")
62-
storage, err := exchangeCode(code, pkce.Verifier)
62+
storage, err := exchangeCode(ctx, code, pkce.Verifier)
6363
if err != nil {
6464
return nil, false, fmt.Errorf("token exchange failed: %w", err)
6565
}
@@ -88,8 +88,8 @@ func buildAuthURL(state string, pkce *PKCEParams) string {
8888
}
8989

9090
// exchangeCode exchanges an authorization code for access + refresh tokens.
91-
func exchangeCode(code, codeVerifier string) (*TokenStorage, error) {
92-
ctx, cancel := context.WithTimeout(context.Background(), tokenExchangeTimeout)
91+
func exchangeCode(ctx context.Context, code, codeVerifier string) (*TokenStorage, error) {
92+
ctx, cancel := context.WithTimeout(ctx, tokenExchangeTimeout)
9393
defer cancel()
9494

9595
data := url.Values{}

callback.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ type callbackResult struct {
3232
// and returns the authorization code (or an error).
3333
//
3434
// The server shuts itself down after the first request.
35-
func startCallbackServer(port int, expectedState string) (string, error) {
35+
func startCallbackServer(ctx context.Context, port int, expectedState string) (string, error) {
3636
resultCh := make(chan callbackResult, 1)
3737

3838
var once sync.Once
@@ -80,7 +80,7 @@ func startCallbackServer(port int, expectedState string) (string, error) {
8080
WriteTimeout: 10 * time.Second,
8181
}
8282

83-
ln, err := net.Listen("tcp", srv.Addr)
83+
ln, err := (&net.ListenConfig{}).Listen(ctx, "tcp", srv.Addr)
8484
if err != nil {
8585
return "", fmt.Errorf("failed to start callback server on port %d: %w", port, err)
8686
}
@@ -90,9 +90,9 @@ func startCallbackServer(port int, expectedState string) (string, error) {
9090
}()
9191

9292
defer func() {
93-
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
93+
shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
9494
defer cancel()
95-
_ = srv.Shutdown(ctx)
95+
_ = srv.Shutdown(shutdownCtx)
9696
}()
9797

9898
select {

callback_test.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"fmt"
56
"io"
67
"net/http"
@@ -11,11 +12,13 @@ import (
1112

1213
// startCallbackServerAsync starts the callback server in a goroutine and
1314
// returns a channel that will receive the authorization code (or error string).
14-
func startCallbackServerAsync(t *testing.T, port int, state string) chan string {
15+
func startCallbackServerAsync(
16+
t *testing.T, ctx context.Context, port int, state string,
17+
) chan string {
1518
t.Helper()
1619
ch := make(chan string, 1)
1720
go func() {
18-
code, err := startCallbackServer(port, state)
21+
code, err := startCallbackServer(ctx, port, state)
1922
if err != nil {
2023
ch <- "ERROR:" + err.Error()
2124
} else {
@@ -31,13 +34,13 @@ func TestCallbackServer_Success(t *testing.T) {
3134
const port = 19101
3235
state := "test-state-success"
3336

34-
ch := startCallbackServerAsync(t, port, state)
37+
ch := startCallbackServerAsync(t, context.Background(), port, state)
3538

3639
callbackURL := fmt.Sprintf(
3740
"http://127.0.0.1:%d/callback?code=mycode123&state=%s",
3841
port, state,
3942
)
40-
resp, err := http.Get(callbackURL) //nolint:noctx
43+
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
4144
if err != nil {
4245
t.Fatalf("GET callback failed: %v", err)
4346
}
@@ -65,13 +68,13 @@ func TestCallbackServer_StateMismatch(t *testing.T) {
6568
const port = 19102
6669
state := "expected-state"
6770

68-
ch := startCallbackServerAsync(t, port, state)
71+
ch := startCallbackServerAsync(t, context.Background(), port, state)
6972

7073
callbackURL := fmt.Sprintf(
7174
"http://127.0.0.1:%d/callback?code=mycode&state=wrong-state",
7275
port,
7376
)
74-
resp, err := http.Get(callbackURL) //nolint:noctx
77+
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
7578
if err != nil {
7679
t.Fatalf("GET callback failed: %v", err)
7780
}
@@ -96,13 +99,13 @@ func TestCallbackServer_OAuthError(t *testing.T) {
9699
const port = 19103
97100
state := "state-for-error"
98101

99-
ch := startCallbackServerAsync(t, port, state)
102+
ch := startCallbackServerAsync(t, context.Background(), port, state)
100103

101104
callbackURL := fmt.Sprintf(
102105
"http://127.0.0.1:%d/callback?error=access_denied&error_description=User+denied&state=%s",
103106
port, state,
104107
)
105-
resp, err := http.Get(callbackURL) //nolint:noctx
108+
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
106109
if err != nil {
107110
t.Fatalf("GET callback failed: %v", err)
108111
}
@@ -130,14 +133,14 @@ func TestCallbackServer_DoubleCallback(t *testing.T) {
130133
const port = 19105
131134
state := "test-state-double"
132135

133-
ch := startCallbackServerAsync(t, port, state)
136+
ch := startCallbackServerAsync(t, context.Background(), port, state)
134137

135138
url := fmt.Sprintf("http://127.0.0.1:%d/callback?code=mycode&state=%s", port, state)
136139

137140
done := make(chan error, 2)
138141
for range 2 {
139142
go func() {
140-
resp, err := http.Get(url) //nolint:noctx
143+
resp, err := http.Get(url) //nolint:noctx,gosec
141144
if err == nil {
142145
resp.Body.Close()
143146
}
@@ -167,13 +170,13 @@ func TestCallbackServer_MissingCode(t *testing.T) {
167170
const port = 19104
168171
state := "state-for-missing-code"
169172

170-
ch := startCallbackServerAsync(t, port, state)
173+
ch := startCallbackServerAsync(t, context.Background(), port, state)
171174

172175
callbackURL := fmt.Sprintf(
173176
"http://127.0.0.1:%d/callback?state=%s",
174177
port, state,
175178
)
176-
resp, err := http.Get(callbackURL) //nolint:noctx
179+
resp, err := http.Get(callbackURL) //nolint:noctx,gosec
177180
if err != nil {
178181
t.Fatalf("GET callback failed: %v", err)
179182
}

detect.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"fmt"
56
"net"
67
"os"
@@ -23,7 +24,7 @@ type BrowserAvailability struct {
2324
// This function never attempts to open a browser itself; it only inspects
2425
// the environment. Callers that pass the check should still handle
2526
// openBrowser() failures as a secondary fallback.
26-
func checkBrowserAvailability(port int) BrowserAvailability {
27+
func checkBrowserAvailability(ctx context.Context, port int) BrowserAvailability {
2728
// Stage 1a: SSH without X11/Wayland forwarding.
2829
// SSH_TTY / SSH_CLIENT / SSH_CONNECTION indicate a remote shell.
2930
// If a display is also present (X11 forwarding), the browser can still open.
@@ -44,7 +45,8 @@ func checkBrowserAvailability(port int) BrowserAvailability {
4445

4546
// Stage 2: Verify the callback port can be bound.
4647
// A busy port means the redirect server cannot start.
47-
ln, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
48+
lc := &net.ListenConfig{}
49+
ln, err := lc.Listen(ctx, "tcp", fmt.Sprintf("127.0.0.1:%d", port))
4850
if err != nil {
4951
return BrowserAvailability{
5052
false,

detect_test.go

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package main
22

33
import (
4+
"context"
45
"net"
56
"testing"
67
)
@@ -12,7 +13,7 @@ func TestCheckBrowserAvailability_SSH_NoDisplay(t *testing.T) {
1213
t.Setenv("DISPLAY", "")
1314
t.Setenv("WAYLAND_DISPLAY", "")
1415

15-
avail := checkBrowserAvailability(18888)
16+
avail := checkBrowserAvailability(context.Background(), 18888)
1617

1718
if avail.Available {
1819
t.Error("expected browser unavailable in SSH session without display")
@@ -29,7 +30,7 @@ func TestCheckBrowserAvailability_SSHClient_NoDisplay(t *testing.T) {
2930
t.Setenv("DISPLAY", "")
3031
t.Setenv("WAYLAND_DISPLAY", "")
3132

32-
avail := checkBrowserAvailability(18888)
33+
avail := checkBrowserAvailability(context.Background(), 18888)
3334

3435
if avail.Available {
3536
t.Error("expected browser unavailable when SSH_CLIENT set and no display")
@@ -43,7 +44,7 @@ func TestCheckBrowserAvailability_SSHConnection_NoDisplay(t *testing.T) {
4344
t.Setenv("DISPLAY", "")
4445
t.Setenv("WAYLAND_DISPLAY", "")
4546

46-
avail := checkBrowserAvailability(18888)
47+
avail := checkBrowserAvailability(context.Background(), 18888)
4748

4849
if avail.Available {
4950
t.Error("expected browser unavailable when SSH_CONNECTION set and no display")
@@ -58,14 +59,14 @@ func TestCheckBrowserAvailability_SSH_WithX11(t *testing.T) {
5859

5960
// Use a port that is definitely free (bind to :0 and get the port,
6061
// then close it; the brief gap is acceptable for a unit test).
61-
ln, err := net.Listen("tcp", "127.0.0.1:0")
62+
ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
6263
if err != nil {
6364
t.Skip("cannot allocate test port")
6465
}
6566
port := ln.Addr().(*net.TCPAddr).Port
6667
ln.Close()
6768

68-
avail := checkBrowserAvailability(port)
69+
avail := checkBrowserAvailability(context.Background(), port)
6970

7071
// X11 forwarding over SSH should be detected as browser-capable
7172
// (DISPLAY is set, port is free).
@@ -76,7 +77,7 @@ func TestCheckBrowserAvailability_SSH_WithX11(t *testing.T) {
7677

7778
func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) {
7879
// Bind a port and keep it busy during the test.
79-
ln, err := net.Listen("tcp", "127.0.0.1:0")
80+
ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
8081
if err != nil {
8182
t.Skip("cannot bind test port")
8283
}
@@ -92,7 +93,7 @@ func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) {
9293
t.Setenv("DISPLAY", ":0")
9394
t.Setenv("WAYLAND_DISPLAY", "")
9495

95-
avail := checkBrowserAvailability(port)
96+
avail := checkBrowserAvailability(context.Background(), port)
9697

9798
if avail.Available {
9899
t.Errorf("expected browser unavailable when port %d is busy", port)
@@ -104,7 +105,7 @@ func TestCheckBrowserAvailability_PortUnavailable(t *testing.T) {
104105

105106
func TestCheckBrowserAvailability_PortAvailable(t *testing.T) {
106107
// Find a free port.
107-
ln, err := net.Listen("tcp", "127.0.0.1:0")
108+
ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
108109
if err != nil {
109110
t.Skip("cannot allocate test port")
110111
}
@@ -117,7 +118,7 @@ func TestCheckBrowserAvailability_PortAvailable(t *testing.T) {
117118
t.Setenv("DISPLAY", ":0")
118119
t.Setenv("WAYLAND_DISPLAY", "")
119120

120-
avail := checkBrowserAvailability(port)
121+
avail := checkBrowserAvailability(context.Background(), port)
121122

122123
if !avail.Available {
123124
t.Errorf(
@@ -128,7 +129,7 @@ func TestCheckBrowserAvailability_PortAvailable(t *testing.T) {
128129
}
129130

130131
func TestCheckBrowserAvailability_ReasonIsEmptyWhenAvailable(t *testing.T) {
131-
ln, err := net.Listen("tcp", "127.0.0.1:0")
132+
ln, err := (&net.ListenConfig{}).Listen(context.Background(), "tcp", "127.0.0.1:0")
132133
if err != nil {
133134
t.Skip("cannot allocate test port")
134135
}
@@ -141,7 +142,7 @@ func TestCheckBrowserAvailability_ReasonIsEmptyWhenAvailable(t *testing.T) {
141142
t.Setenv("DISPLAY", ":0")
142143
t.Setenv("WAYLAND_DISPLAY", "")
143144

144-
avail := checkBrowserAvailability(port)
145+
avail := checkBrowserAvailability(context.Background(), port)
145146

146147
if avail.Available && avail.Reason != "" {
147148
t.Errorf("expected empty reason when browser is available, got: %s", avail.Reason)

filelock_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func TestConcurrentLocks(t *testing.T) {
6262
concurrent--
6363
mu.Unlock()
6464

65-
lock.release()
65+
_ = lock.release()
6666
}(i)
6767
}
6868

@@ -89,5 +89,5 @@ func TestStaleLockRemoval(t *testing.T) {
8989
if err != nil {
9090
t.Fatalf("acquireFileLock() with stale lock: %v", err)
9191
}
92-
lock.release()
92+
_ = lock.release()
9393
}

0 commit comments

Comments
 (0)