-
Notifications
You must be signed in to change notification settings - Fork 0
refactor(oauth): fix backoff bug and clean up error handling #17
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -181,7 +181,8 @@ func initConfig() { | |
| retry.WithHTTPClient(baseHTTPClient), | ||
| ) | ||
| if err != nil { | ||
| panic(fmt.Sprintf("failed to create retry client: %v", err)) | ||
| fmt.Fprintf(os.Stderr, "Error: failed to create retry client: %v\n", err) | ||
| os.Exit(1) | ||
| } | ||
|
|
||
| // Initialize token store based on mode | ||
|
|
@@ -387,7 +388,7 @@ func run(ctx context.Context, d tui.Displayer) error { | |
| // Demonstrate automatic refresh on 401 | ||
| if err := makeAPICallWithAutoRefresh(ctx, &storage, d); err != nil { | ||
| // Check if error is due to expired refresh token | ||
| if err == ErrRefreshTokenExpired { | ||
| if errors.Is(err, ErrRefreshTokenExpired) { | ||
| d.ReAuthRequired() | ||
| storage, err = performDeviceFlow(ctx, d) | ||
| if err != nil { | ||
|
|
@@ -401,7 +402,6 @@ func run(ctx context.Context, d tui.Displayer) error { | |
| d.Fatal(err) | ||
| return err | ||
| } | ||
| d.APICallOK() | ||
| } else { | ||
| d.APICallFailed(err) | ||
| } | ||
|
|
@@ -477,13 +477,13 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error) | |
|
|
||
| // performDeviceFlow performs the OAuth device authorization flow | ||
| func performDeviceFlow(ctx context.Context, d tui.Displayer) (credstore.Token, error) { | ||
| // Only TokenURL and ClientID are used downstream; | ||
| // requestDeviceCode() builds its own request directly. | ||
| config := &oauth2.Config{ | ||
| ClientID: clientID, | ||
| Endpoint: oauth2.Endpoint{ | ||
| DeviceAuthURL: serverURL + endpointDeviceCode, | ||
| TokenURL: serverURL + endpointToken, | ||
| TokenURL: serverURL + endpointToken, | ||
| }, | ||
| Scopes: []string{"read", "write"}, | ||
| } | ||
|
|
||
| // Step 1: Request device code (with retry logic) | ||
|
|
@@ -540,9 +540,8 @@ func pollForTokenWithProgress( | |
| interval = 5 // Default to 5 seconds per RFC 8628 | ||
| } | ||
|
|
||
| // Exponential backoff state | ||
| // Backoff state | ||
| pollInterval := time.Duration(interval) * time.Second | ||
| backoffMultiplier := 1.0 | ||
|
|
||
| pollTicker := time.NewTicker(pollInterval) | ||
| defer pollTicker.Stop() | ||
|
|
@@ -572,12 +571,8 @@ func pollForTokenWithProgress( | |
| continue | ||
|
|
||
| case oauthErrSlowDown: | ||
| // Server requests slower polling - increase interval | ||
| backoffMultiplier *= 1.5 | ||
| pollInterval = min( | ||
| time.Duration(float64(pollInterval)*backoffMultiplier), | ||
| 60*time.Second, | ||
| ) | ||
| // Server requests slower polling - add 5s per RFC 8628 §3.5 | ||
| pollInterval = min(pollInterval+5*time.Second, 60*time.Second) | ||
| pollTicker.Reset(pollInterval) | ||
| d.PollSlowDown(pollInterval) | ||
| continue | ||
|
|
@@ -831,24 +826,26 @@ func makeAPICallWithAutoRefresh( | |
| } | ||
| defer resp.Body.Close() | ||
|
|
||
| // If 401, try to refresh and retry | ||
| // If 401, drain and close the first response body to allow connection reuse, | ||
| // then refresh the token and retry. | ||
| if resp.StatusCode == http.StatusUnauthorized { | ||
| _, _ = io.Copy(io.Discard, resp.Body) | ||
| resp.Body.Close() | ||
|
|
||
|
Comment on lines
+828
to
+833
|
||
| d.AccessTokenRejected() | ||
|
|
||
| newStorage, err := refreshAccessToken(ctx, storage.RefreshToken, d) | ||
| if err != nil { | ||
| // If refresh token is expired, propagate the error to trigger device flow | ||
| if err == ErrRefreshTokenExpired { | ||
| if errors.Is(err, ErrRefreshTokenExpired) { | ||
| return ErrRefreshTokenExpired | ||
| } | ||
| return fmt.Errorf("refresh failed: %w", err) | ||
| } | ||
|
|
||
| // Update storage in memory | ||
| // Note: newStorage has already been saved to disk by refreshAccessToken() | ||
| storage.AccessToken = newStorage.AccessToken | ||
| storage.RefreshToken = newStorage.RefreshToken | ||
| storage.ExpiresAt = newStorage.ExpiresAt | ||
| *storage = newStorage | ||
|
|
||
| d.TokenRefreshedRetrying() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -77,10 +77,10 @@ func TestPollForToken_SlowDown(t *testing.T) { | |||||||||||
| slowDownCount := atomic.Int32{} | ||||||||||||
|
|
||||||||||||
| server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||||||||||||
| attempts.Add(1) | ||||||||||||
| n := attempts.Add(1) | ||||||||||||
|
|
||||||||||||
| // Return slow_down for first 2 attempts | ||||||||||||
| if attempts.Load() <= 2 { | ||||||||||||
| // Return slow_down on the first attempt | ||||||||||||
| if n == 1 { | ||||||||||||
| slowDownCount.Add(1) | ||||||||||||
| w.Header().Set("Content-Type", "application/json") | ||||||||||||
| w.WriteHeader(http.StatusBadRequest) | ||||||||||||
|
|
@@ -91,8 +91,8 @@ func TestPollForToken_SlowDown(t *testing.T) { | |||||||||||
| return | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Return authorization_pending after slow_down | ||||||||||||
| if attempts.Load() < 5 { | ||||||||||||
| // Return authorization_pending on second attempt | ||||||||||||
| if n == 2 { | ||||||||||||
| w.Header().Set("Content-Type", "application/json") | ||||||||||||
| w.WriteHeader(http.StatusBadRequest) | ||||||||||||
| _ = json.NewEncoder(w).Encode(map[string]string{ | ||||||||||||
|
|
@@ -102,7 +102,7 @@ func TestPollForToken_SlowDown(t *testing.T) { | |||||||||||
| return | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Success | ||||||||||||
| // Success on third attempt | ||||||||||||
| w.Header().Set("Content-Type", "application/json") | ||||||||||||
| _ = json.NewEncoder(w).Encode(map[string]any{ | ||||||||||||
| "access_token": testAccessToken, | ||||||||||||
|
|
@@ -125,6 +125,7 @@ func TestPollForToken_SlowDown(t *testing.T) { | |||||||||||
| Interval: 1, // 1 second for testing | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // After 1 slow_down the interval becomes 1+5=6s, so we need ~8s total. | ||||||||||||
| ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) | ||||||||||||
|
||||||||||||
| // After 1 slow_down the interval becomes 1+5=6s, so we need ~8s total. | |
| ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) | |
| // After 1 slow_down the interval becomes 1+5=6s; with an additional authorization_pending | |
| // before success, the third attempt occurs after ~1s + 6s + 6s ≈ 13s, so use a generous timeout. | |
| ctx, cancel := context.WithTimeout(context.Background(), 25*time.Second) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block now implements a linear +5s backoff on
slow_down(per RFC 8628), but the surrounding function-level documentation still mentions “exponential backoff”. Please update the doc/comment wording to match the new behavior so future readers aren’t misled.