Skip to content

Commit 7994db9

Browse files
appleboyclaude
andcommitted
refactor(oauth): fix backoff bug and clean up error handling
- Use additive 5s backoff per RFC 8628 §3.5 instead of compounding multiplier - Replace direct error equality with errors.Is() for ErrRefreshTokenExpired - Remove duplicate APICallOK() call on re-auth retry path - Use full struct assignment instead of partial field copies in auto-refresh - Drain response body before 401 retry to allow connection pool reuse - Remove unused DeviceAuthURL and Scopes from oauth2.Config - Replace panic() with fmt.Fprintf + os.Exit(1) in initConfig() - Replace custom contains() helpers with strings.Contains() in tests - Update slow_down test for additive backoff behavior Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent bfe4d22 commit 7994db9

3 files changed

Lines changed: 30 additions & 46 deletions

File tree

main.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ func initConfig() {
181181
retry.WithHTTPClient(baseHTTPClient),
182182
)
183183
if err != nil {
184-
panic(fmt.Sprintf("failed to create retry client: %v", err))
184+
fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err)
185+
os.Exit(1)
185186
}
186187

187188
// Initialize token store based on mode
@@ -387,7 +388,7 @@ func run(ctx context.Context, d tui.Displayer) error {
387388
// Demonstrate automatic refresh on 401
388389
if err := makeAPICallWithAutoRefresh(ctx, &storage, d); err != nil {
389390
// Check if error is due to expired refresh token
390-
if err == ErrRefreshTokenExpired {
391+
if errors.Is(err, ErrRefreshTokenExpired) {
391392
d.ReAuthRequired()
392393
storage, err = performDeviceFlow(ctx, d)
393394
if err != nil {
@@ -401,7 +402,6 @@ func run(ctx context.Context, d tui.Displayer) error {
401402
d.Fatal(err)
402403
return err
403404
}
404-
d.APICallOK()
405405
} else {
406406
d.APICallFailed(err)
407407
}
@@ -477,13 +477,13 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
477477

478478
// performDeviceFlow performs the OAuth device authorization flow
479479
func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, error) {
480+
// Only TokenURL and ClientID are used downstream;
481+
// requestDeviceCode() builds its own request directly.
480482
config := &oauth2.Config{
481483
ClientID: clientID,
482484
Endpoint: oauth2.Endpoint{
483-
DeviceAuthURL: serverURL + endpointDeviceCode,
484-
TokenURL: serverURL + endpointToken,
485+
TokenURL: serverURL + endpointToken,
485486
},
486-
Scopes: []string{"read", "write"},
487487
}
488488

489489
// Step 1: Request device code (with retry logic)
@@ -540,9 +540,8 @@ func pollForTokenWithProgress(
540540
interval = 5 // Default to 5 seconds per RFC 8628
541541
}
542542

543-
// Exponential backoff state
543+
// Backoff state
544544
pollInterval := time.Duration(interval) * time.Second
545-
backoffMultiplier := 1.0
546545

547546
pollTicker := time.NewTicker(pollInterval)
548547
defer pollTicker.Stop()
@@ -572,12 +571,8 @@ func pollForTokenWithProgress(
572571
continue
573572

574573
case oauthErrSlowDown:
575-
// Server requests slower polling - increase interval
576-
backoffMultiplier *= 1.5
577-
pollInterval = min(
578-
time.Duration(float64(pollInterval)*backoffMultiplier),
579-
60*time.Second,
580-
)
574+
// Server requests slower polling - add 5s per RFC 8628 §3.5
575+
pollInterval = min(pollInterval+5*time.Second, 60*time.Second)
581576
pollTicker.Reset(pollInterval)
582577
d.PollSlowDown(pollInterval)
583578
continue
@@ -831,24 +826,26 @@ func makeAPICallWithAutoRefresh(
831826
}
832827
defer resp.Body.Close()
833828

834-
// If 401, try to refresh and retry
829+
// If 401, drain and close the first response body to allow connection reuse,
830+
// then refresh the token and retry.
835831
if resp.StatusCode == http.StatusUnauthorized {
832+
_, _ = io.Copy(io.Discard, resp.Body)
833+
resp.Body.Close()
834+
836835
d.AccessTokenRejected()
837836

838837
newStorage, err := refreshAccessToken(ctx, storage.RefreshToken, d)
839838
if err != nil {
840839
// If refresh token is expired, propagate the error to trigger device flow
841-
if err == ErrRefreshTokenExpired {
840+
if errors.Is(err, ErrRefreshTokenExpired) {
842841
return ErrRefreshTokenExpired
843842
}
844843
return fmt.Errorf("refresh failed: %w", err)
845844
}
846845

847846
// Update storage in memory
848847
// Note: newStorage has already been saved to disk by refreshAccessToken()
849-
storage.AccessToken = newStorage.AccessToken
850-
storage.RefreshToken = newStorage.RefreshToken
851-
storage.ExpiresAt = newStorage.ExpiresAt
848+
*storage = newStorage
852849

853850
d.TokenRefreshedRetrying()
854851

main_test.go

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http/httptest"
99
"os"
1010
"path/filepath"
11+
"strings"
1112
"sync"
1213
"sync/atomic"
1314
"testing"
@@ -271,7 +272,7 @@ func TestValidateTokenResponse(t *testing.T) {
271272
t.Errorf("validateTokenResponse() expected error but got nil")
272273
return
273274
}
274-
if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
275+
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
275276
t.Errorf(
276277
"validateTokenResponse() error = %v, want error containing %q",
277278
err,
@@ -285,21 +286,6 @@ func TestValidateTokenResponse(t *testing.T) {
285286
}
286287
}
287288

288-
// contains checks if string s contains substr
289-
func contains(s, substr string) bool {
290-
return len(s) >= len(substr) && (s == substr || len(substr) == 0 ||
291-
(len(s) > 0 && len(substr) > 0 && stringContains(s, substr)))
292-
}
293-
294-
func stringContains(s, substr string) bool {
295-
for i := 0; i <= len(s)-len(substr); i++ {
296-
if s[i:i+len(substr)] == substr {
297-
return true
298-
}
299-
}
300-
return false
301-
}
302-
303289
func TestRefreshAccessToken_RotationMode(t *testing.T) {
304290
// Save original values
305291
origServerURL := serverURL
@@ -526,7 +512,7 @@ func TestRefreshAccessToken_ValidationErrors(t *testing.T) {
526512
t.Errorf("refreshAccessToken() expected error but got nil")
527513
return
528514
}
529-
if tt.errContains != "" && !contains(err.Error(), tt.errContains) {
515+
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
530516
t.Errorf(
531517
"refreshAccessToken() error = %v, want error containing %q",
532518
err,

polling_test.go

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,10 @@ func TestPollForToken_SlowDown(t *testing.T) {
7777
slowDownCount := atomic.Int32{}
7878

7979
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80-
attempts.Add(1)
80+
n := attempts.Add(1)
8181

82-
// Return slow_down for first 2 attempts
83-
if attempts.Load() <= 2 {
82+
// Return slow_down on the first attempt
83+
if n == 1 {
8484
slowDownCount.Add(1)
8585
w.Header().Set("Content-Type", "application/json")
8686
w.WriteHeader(http.StatusBadRequest)
@@ -91,8 +91,8 @@ func TestPollForToken_SlowDown(t *testing.T) {
9191
return
9292
}
9393

94-
// Return authorization_pending after slow_down
95-
if attempts.Load() < 5 {
94+
// Return authorization_pending on second attempt
95+
if n == 2 {
9696
w.Header().Set("Content-Type", "application/json")
9797
w.WriteHeader(http.StatusBadRequest)
9898
_ = json.NewEncoder(w).Encode(map[string]string{
@@ -102,7 +102,7 @@ func TestPollForToken_SlowDown(t *testing.T) {
102102
return
103103
}
104104

105-
// Success
105+
// Success on third attempt
106106
w.Header().Set("Content-Type", "application/json")
107107
_ = json.NewEncoder(w).Encode(map[string]any{
108108
"access_token": testAccessToken,
@@ -125,6 +125,7 @@ func TestPollForToken_SlowDown(t *testing.T) {
125125
Interval: 1, // 1 second for testing
126126
}
127127

128+
// After 1 slow_down the interval becomes 1+5=6s, so we need ~8s total.
128129
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
129130
defer cancel()
130131

@@ -137,14 +138,14 @@ func TestPollForToken_SlowDown(t *testing.T) {
137138
t.Errorf("Expected access token 'test-access-token', got '%s'", token.AccessToken)
138139
}
139140

140-
if slowDownCount.Load() < 2 {
141-
t.Errorf("Expected at least 2 slow_down responses, got %d", slowDownCount.Load())
141+
if slowDownCount.Load() < 1 {
142+
t.Errorf("Expected at least 1 slow_down response, got %d", slowDownCount.Load())
142143
}
143144

144145
// Verify that polling continued after slow_down
145-
if attempts.Load() < 5 {
146+
if attempts.Load() < 3 {
146147
t.Errorf(
147-
"Expected at least 5 attempts (2 slow_down + 2 pending + 1 success), got %d",
148+
"Expected at least 3 attempts (1 slow_down + 1 pending + 1 success), got %d",
148149
attempts.Load(),
149150
)
150151
}

0 commit comments

Comments
 (0)