Skip to content

Commit b1def4f

Browse files
appleboyclaude
andauthored
refactor(oauth): fix backoff bug and clean up error handling (#17)
* refactor(oauth): fix backoff bug and clean up error handling - Use additive 5s backoff per RFC 8628 §3.5 instead of compounding multiplier - Replace direct error equality with errors.Is() for ErrRefreshTokenExpired - Remove duplicate APICallOK() call on re-auth retry path - Use full struct assignment instead of partial field copies in auto-refresh - Drain response body before 401 retry to allow connection pool reuse - Remove unused DeviceAuthURL and Scopes from oauth2.Config - Replace panic() with fmt.Fprintf + os.Exit(1) in initConfig() - Replace custom contains() helpers with strings.Contains() in tests - Update slow_down test for additive backoff behavior Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(review): address PR feedback from Copilot review - Fix double-close of response body by using explicit close in 401 path and defer in else branch instead of unconditional defer - Increase slow_down test timeout from 15s to 25s to prevent flakiness - Update pollForTokenWithProgress doc comment to say "additive backoff" instead of "exponential backoff" Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bfe4d22 commit b1def4f

3 files changed

Lines changed: 35 additions & 49 deletions

File tree

main.go

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ func initConfig() {
181181
retry.WithHTTPClient(baseHTTPClient),
182182
)
183183
if err != nil {
184-
panic(fmt.Sprintf("failed to create retry client: %v", err))
184+
fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err)
185+
os.Exit(1)
185186
}
186187

187188
// Initialize token store based on mode
@@ -387,7 +388,7 @@ func run(ctx context.Context, d tui.Displayer) error {
387388
// Demonstrate automatic refresh on 401
388389
if err := makeAPICallWithAutoRefresh(ctx, &storage, d); err != nil {
389390
// Check if error is due to expired refresh token
390-
if err == ErrRefreshTokenExpired {
391+
if errors.Is(err, ErrRefreshTokenExpired) {
391392
d.ReAuthRequired()
392393
storage, err = performDeviceFlow(ctx, d)
393394
if err != nil {
@@ -401,7 +402,6 @@ func run(ctx context.Context, d tui.Displayer) error {
401402
d.Fatal(err)
402403
return err
403404
}
404-
d.APICallOK()
405405
} else {
406406
d.APICallFailed(err)
407407
}
@@ -477,13 +477,13 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
477477

478478
// performDeviceFlow performs the OAuth device authorization flow
479479
func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, error) {
480+
// Only TokenURL and ClientID are used downstream;
481+
// requestDeviceCode() builds its own request directly.
480482
config := &oauth2.Config{
481483
ClientID: clientID,
482484
Endpoint: oauth2.Endpoint{
483-
DeviceAuthURL: serverURL + endpointDeviceCode,
484-
TokenURL: serverURL + endpointToken,
485+
TokenURL: serverURL + endpointToken,
485486
},
486-
Scopes: []string{"read", "write"},
487487
}
488488

489489
// Step 1: Request device code (with retry logic)
@@ -527,7 +527,7 @@ func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, e
527527
}
528528

529529
// pollForTokenWithProgress polls for token while reporting progress via Displayer.
530-
// Implements exponential backoff for slow_down errors per RFC 8628.
530+
// Implements additive backoff (+5s) for slow_down errors per RFC 8628 §3.5.
531531
func pollForTokenWithProgress(
532532
ctx context.Context,
533533
config *oauth2.Config,
@@ -540,9 +540,8 @@ func pollForTokenWithProgress(
540540
interval = 5 // Default to 5 seconds per RFC 8628
541541
}
542542

543-
// Exponential backoff state
543+
// Backoff state
544544
pollInterval := time.Duration(interval) * time.Second
545-
backoffMultiplier := 1.0
546545

547546
pollTicker := time.NewTicker(pollInterval)
548547
defer pollTicker.Stop()
@@ -572,12 +571,8 @@ func pollForTokenWithProgress(
572571
continue
573572

574573
case oauthErrSlowDown:
575-
// Server requests slower polling - increase interval
576-
backoffMultiplier *= 1.5
577-
pollInterval = min(
578-
time.Duration(float64(pollInterval)*backoffMultiplier),
579-
60*time.Second,
580-
)
574+
// Server requests slower polling - add 5s per RFC 8628 §3.5
575+
pollInterval = min(pollInterval+5*time.Second, 60*time.Second)
581576
pollTicker.Reset(pollInterval)
582577
d.PollSlowDown(pollInterval)
583578
continue
@@ -829,26 +824,27 @@ func makeAPICallWithAutoRefresh(
829824
if err != nil {
830825
return fmt.Errorf("API request failed: %w", err)
831826
}
832-
defer resp.Body.Close()
833827

834-
// If 401, try to refresh and retry
828+
// If 401, drain and close the first response body to allow connection reuse,
829+
// then refresh the token and retry.
835830
if resp.StatusCode == http.StatusUnauthorized {
831+
_, _ = io.Copy(io.Discard, resp.Body)
832+
resp.Body.Close()
833+
836834
d.AccessTokenRejected()
837835

838836
newStorage, err := refreshAccessToken(ctx, storage.RefreshToken, d)
839837
if err != nil {
840838
// If refresh token is expired, propagate the error to trigger device flow
841-
if err == ErrRefreshTokenExpired {
839+
if errors.Is(err, ErrRefreshTokenExpired) {
842840
return ErrRefreshTokenExpired
843841
}
844842
return fmt.Errorf("refresh failed: %w", err)
845843
}
846844

847845
// Update storage in memory
848846
// Note: newStorage has already been saved to disk by refreshAccessToken()
849-
storage.AccessToken = newStorage.AccessToken
850-
storage.RefreshToken = newStorage.RefreshToken
851-
storage.ExpiresAt = newStorage.ExpiresAt
847+
*storage = newStorage
852848

853849
d.TokenRefreshedRetrying()
854850

@@ -869,6 +865,8 @@ func makeAPICallWithAutoRefresh(
869865
return fmt.Errorf("retry failed: %w", err)
870866
}
871867
defer resp.Body.Close()
868+
} else {
869+
defer resp.Body.Close()
872870
}
873871

874872
body, err := io.ReadAll(resp.Body)

main_test.go

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http/httptest"
99
"os"
1010
"path/filepath"
11+
"strings"
1112
"sync"
1213
"sync/atomic"
1314
"testing"
@@ -271,7 +272,7 @@ func TestValidateTokenResponse(t *testing.T) {
271272
t.Errorf("validateTokenResponse() expected error but got nil")
272273
return
273274
}
274-
if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
275+
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
275276
t.Errorf(
276277
"validateTokenResponse() error = %v, want error containing %q",
277278
err,
@@ -285,21 +286,6 @@ func TestValidateTokenResponse(t *testing.T) {
285286
}
286287
}
287288

288-
// contains checks if string s contains substr
289-
func contains(s, substr string) bool {
290-
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
291-
(len(s) > 0 && len(substr) > 0 && stringContains(s, substr)))
292-
}
293-
294-
func stringContains(s, substr string) bool {
295-
for i := 0; i <= len(s)-len(substr); i++ {
296-
if s[i:i+len(substr)] == substr {
297-
return true
298-
}
299-
}
300-
return false
301-
}
302-
303289
func TestRefreshAccessToken_RotationMode(t *testing.T) {
304290
// Save original values
305291
origServerURL := serverURL
@@ -526,7 +512,7 @@ func TestRefreshAccessToken_ValidationErrors(t *testing.T) {
526512
t.Errorf("refreshAccessToken() expected error but got nil")
527513
return
528514
}
529-
if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
515+
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
530516
t.Errorf(
531517
"refreshAccessToken() error = %v, want error containing %q",
532518
err,

polling_test.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ func TestPollForToken_SlowDown(t *testing.T) {
7777
slowDownCount := atomic.Int32{}
7878

7979
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80-
attempts.Add(1)
80+
n := attempts.Add(1)
8181

82-
// Return slow_down for first 2 attempts
83-
if attempts.Load() <= 2 {
82+
// Return slow_down on the first attempt
83+
if n == 1 {
8484
slowDownCount.Add(1)
8585
w.Header().Set("Content-Type", "application/json")
8686
w.WriteHeader(http.StatusBadRequest)
@@ -91,8 +91,8 @@ func TestPollForToken_SlowDown(t *testing.T) {
9191
return
9292
}
9393

94-
// Return authorization_pending after slow_down
95-
if attempts.Load() < 5 {
94+
// Return authorization_pending on second attempt
95+
if n == 2 {
9696
w.Header().Set("Content-Type", "application/json")
9797
w.WriteHeader(http.StatusBadRequest)
9898
_ = json.NewEncoder(w).Encode(map[string]string{
@@ -102,7 +102,7 @@ func TestPollForToken_SlowDown(t *testing.T) {
102102
return
103103
}
104104

105-
// Success
105+
// Success on third attempt
106106
w.Header().Set("Content-Type", "application/json")
107107
_ = json.NewEncoder(w).Encode(map[string]any{
108108
"access_token": testAccessToken,
@@ -125,7 +125,9 @@ func TestPollForToken_SlowDown(t *testing.T) {
125125
Interval: 1, // 1 second for testing
126126
}
127127

128-
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
128+
// After 1 slow_down the interval becomes 1+5=6s; with an additional authorization_pending
129+
// before success, the third attempt occurs after ~1s + 6s + 6s ≈ 13s, so use a generous timeout.
130+
ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second)
129131
defer cancel()
130132

131133
token, err := pollForTokenWithProgress(ctx, config, deviceAuth, tui.NoopDisplayer{})
@@ -137,14 +139,14 @@ func TestPollForToken_SlowDown(t *testing.T) {
137139
t.Errorf("Expected access token 'test-access-token', got '%s'", token.AccessToken)
138140
}
139141

140-
if slowDownCount.Load() < 2 {
141-
t.Errorf("Expected at least 2 slow_down responses, got %d", slowDownCount.Load())
142+
if slowDownCount.Load() < 1 {
143+
t.Errorf("Expected at least 1 slow_down response, got %d", slowDownCount.Load())
142144
}
143145

144146
// Verify that polling continued after slow_down
145-
if attempts.Load() < 5 {
147+
if attempts.Load() < 3 {
146148
t.Errorf(
147-
"Expected at least 5 attempts (2 slow_down + 2 pending + 1 success), got %d",
149+
"Expected at least 3 attempts (1 slow_down + 1 pending + 1 success), got %d",
148150
attempts.Load(),
149151
)
150152
}

0 commit comments

Comments
 (0)