Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ func initConfig() {
retry.WithHTTPClient(baseHTTPClient),
)
if err != nil {
panic(fmt.Sprintf("failed to create retry client: %v", err))
fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err)
os.Exit(1)
}

// Initialize token store based on mode
Expand Down Expand Up @@ -387,7 +388,7 @@ func run(ctx context.Context, d tui.Displayer) error {
// Demonstrate automatic refresh on 401
if err := makeAPICallWithAutoRefresh(ctx, &storage, d); err != nil {
// Check if error is due to expired refresh token
if err == ErrRefreshTokenExpired {
if errors.Is(err, ErrRefreshTokenExpired) {
d.ReAuthRequired()
storage, err = performDeviceFlow(ctx, d)
if err != nil {
Expand All @@ -401,7 +402,6 @@ func run(ctx context.Context, d tui.Displayer) error {
d.Fatal(err)
return err
}
d.APICallOK()
} else {
d.APICallFailed(err)
}
Expand Down Expand Up @@ -477,13 +477,13 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)

// performDeviceFlow performs the OAuth device authorization flow
func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, error) {
// Only TokenURL and ClientID are used downstream;
// requestDeviceCode() builds its own request directly.
config := &oauth2.Config{
ClientID: clientID,
Endpoint: oauth2.Endpoint{
DeviceAuthURL: serverURL + endpointDeviceCode,
TokenURL: serverURL + endpointToken,
TokenURL: serverURL + endpointToken,
},
Scopes: []string{"read", "write"},
}

// Step 1: Request device code (with retry logic)
Expand Down Expand Up @@ -540,9 +540,8 @@ func pollForTokenWithProgress(
interval = 5 // Default to 5 seconds per RFC 8628
}

// Exponential backoff state
// Backoff state
pollInterval := time.Duration(interval) * time.Second
backoffMultiplier := 1.0

Comment on lines +543 to 545
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block now implements a linear +5s backoff on slow_down (per RFC 8628), but the surrounding function-level documentation still mentions “exponential backoff”. Please update the doc/comment wording to match the new behavior so future readers aren’t misled.

Copilot uses AI. Check for mistakes.
pollTicker := time.NewTicker(pollInterval)
defer pollTicker.Stop()
Expand Down Expand Up @@ -572,12 +571,8 @@ func pollForTokenWithProgress(
continue

case oauthErrSlowDown:
// Server requests slower polling - increase interval
backoffMultiplier *= 1.5
pollInterval = min(
time.Duration(float64(pollInterval)*backoffMultiplier),
60*time.Second,
)
// Server requests slower polling - add 5s per RFC 8628 §3.5
pollInterval = min(pollInterval+5*time.Second, 60*time.Second)
pollTicker.Reset(pollInterval)
d.PollSlowDown(pollInterval)
continue
Expand Down Expand Up @@ -831,24 +826,26 @@ func makeAPICallWithAutoRefresh(
}
defer resp.Body.Close()

// If 401, try to refresh and retry
// If 401, drain and close the first response body to allow connection reuse,
// then refresh the token and retry.
if resp.StatusCode == http.StatusUnauthorized {
_, _ = io.Copy(io.Discard, resp.Body)
resp.Body.Close()

Comment on lines +828 to +833
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

resp.Body.Close() is deferred earlier, but the 401 branch also explicitly closes the body after draining it. That can lead to a double-close of the original response body. Consider restructuring so the initial defer resp.Body.Close() is only set for the non-401 path (or removed and replaced with explicit closes) while still draining+closing before the retry to allow connection reuse.

Copilot uses AI. Check for mistakes.
d.AccessTokenRejected()

newStorage, err := refreshAccessToken(ctx, storage.RefreshToken, d)
if err != nil {
// If refresh token is expired, propagate the error to trigger device flow
if err == ErrRefreshTokenExpired {
if errors.Is(err, ErrRefreshTokenExpired) {
return ErrRefreshTokenExpired
}
return fmt.Errorf("refresh failed: %w", err)
}

// Update storage in memory
// Note: newStorage has already been saved to disk by refreshAccessToken()
storage.AccessToken = newStorage.AccessToken
storage.RefreshToken = newStorage.RefreshToken
storage.ExpiresAt = newStorage.ExpiresAt
*storage = newStorage

d.TokenRefreshedRetrying()

Expand Down
20 changes: 3 additions & 17 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -271,7 +272,7 @@ func TestValidateTokenResponse(t *testing.T) {
t.Errorf("validateTokenResponse() expected error but got nil")
return
}
if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf(
"validateTokenResponse() error = %v, want error containing %q",
err,
Expand All @@ -285,21 +286,6 @@ func TestValidateTokenResponse(t *testing.T) {
}
}

// contains checks if string s contains substr
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
(len(s) > 0 && len(substr) > 0 && stringContains(s, substr)))
}

func stringContains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

func TestRefreshAccessToken_RotationMode(t *testing.T) {
// Save original values
origServerURL := serverURL
Expand Down Expand Up @@ -526,7 +512,7 @@ func TestRefreshAccessToken_ValidationErrors(t *testing.T) {
t.Errorf("refreshAccessToken() expected error but got nil")
return
}
if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
t.Errorf(
"refreshAccessToken() error = %v, want error containing %q",
err,
Expand Down
21 changes: 11 additions & 10 deletions polling_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ func TestPollForToken_SlowDown(t *testing.T) {
slowDownCount := atomic.Int32{}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attempts.Add(1)
n := attempts.Add(1)

// Return slow_down for first 2 attempts
if attempts.Load() <= 2 {
// Return slow_down on the first attempt
if n == 1 {
slowDownCount.Add(1)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
Expand All @@ -91,8 +91,8 @@ func TestPollForToken_SlowDown(t *testing.T) {
return
}

// Return authorization_pending after slow_down
if attempts.Load() < 5 {
// Return authorization_pending on second attempt
if n == 2 {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusBadRequest)
_ = json.NewEncoder(w).Encode(map[string]string{
Expand All @@ -102,7 +102,7 @@ func TestPollForToken_SlowDown(t *testing.T) {
return
}

// Success
// Success on third attempt
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{
"access_token": testAccessToken,
Expand All @@ -125,6 +125,7 @@ func TestPollForToken_SlowDown(t *testing.T) {
Interval: 1, // 1 second for testing
}

// After 1 slow_down the interval becomes 1+5=6s, so we need ~8s total.
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
Copy link

Copilot AI Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The timing estimate here looks incorrect and makes the test timeout very tight (potentially flaky). With Interval=1s, after the first slow_down the poll interval becomes 6s; with one authorization_pending before success, the third attempt won’t occur until ~1s + 6s + 6s ≈ 13s (plus overhead). Consider either increasing the context timeout buffer, or adjusting the mocked server sequence so success happens on the second attempt after slow_down if you want a shorter test.

Suggested change
// After 1 slow_down the interval becomes 1+5=6s, so we need ~8s total.
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
// After 1 slow_down the interval becomes 1+5=6s; with an additional authorization_pending
// before success, the third attempt occurs after ~1s + 6s + 6s ≈ 13s, so use a generous timeout.
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)

Copilot uses AI. Check for mistakes.
defer cancel()

Expand All @@ -137,14 +138,14 @@ func TestPollForToken_SlowDown(t *testing.T) {
t.Errorf("Expected access token 'test-access-token', got '%s'", token.AccessToken)
}

if slowDownCount.Load() < 2 {
t.Errorf("Expected at least 2 slow_down responses, got %d", slowDownCount.Load())
if slowDownCount.Load() < 1 {
t.Errorf("Expected at least 1 slow_down response, got %d", slowDownCount.Load())
}

// Verify that polling continued after slow_down
if attempts.Load() < 5 {
if attempts.Load() < 3 {
t.Errorf(
"Expected at least 5 attempts (2 slow_down + 2 pending + 1 success), got %d",
"Expected at least 3 attempts (1 slow_down + 1 pending + 1 success), got %d",
attempts.Load(),
)
}
Expand Down
Loading