Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions internal/api/auth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -119,7 +120,7 @@ func TestValidateAPIKey_Forbidden(t *testing.T) {

func TestValidateAPIKey_ServerError(t *testing.T) {
origSleep := retrySleepFn
retrySleepFn = func(d time.Duration) {} // no-op for fast tests
retrySleepFn = func(context.Context, time.Duration) error { return nil } // no-op for fast tests
defer func() { retrySleepFn = origSleep }()

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -164,7 +165,7 @@ func TestValidateAPIKey_EmptyResults(t *testing.T) {

func TestValidateAPIKey_ConnectionError(t *testing.T) {
origSleep := retrySleepFn
retrySleepFn = func(d time.Duration) {} // no-op for fast tests
retrySleepFn = func(context.Context, time.Duration) error { return nil } // no-op for fast tests
defer func() { retrySleepFn = origSleep }()

// Use a non-existent URL to simulate connection failure.
Expand Down
6 changes: 5 additions & 1 deletion internal/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,11 @@ type bearerAuthTransport struct {
}

func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token, err := t.tokenCache.Token()
// Propagate the request context to the token fetch. KLA-448 — the
// previous implementation called Token() with no context, so a
// short --probe-timeout couldn't reach the OAuth endpoint and the
// caller waited through the http.Client's 30s default.
token, err := t.tokenCache.Token(req.Context())
if err != nil {
return nil, fmt.Errorf("failed to obtain bearer token: %w", err)
}
Expand Down
3 changes: 2 additions & 1 deletion internal/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api

import (
"bytes"
"context"
"crypto/tls"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -436,7 +437,7 @@ func TestLoggingTransport_VerboseLogsError(t *testing.T) {
defer resetViper()

origSleep := retrySleepFn
retrySleepFn = func(d time.Duration) {}
retrySleepFn = func(context.Context, time.Duration) error { return nil }
defer func() { retrySleepFn = origSleep }()

viper.Set("verbose", true)
Expand Down
146 changes: 146 additions & 0 deletions internal/api/context_propagation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
package api

import (
"context"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)

// Pre-KLA-448 these scenarios would all wait through internal timeouts
// (30s OAuth http.Client, uninterruptible backoff sleep). Post-fix the
// context cancels at every layer.

// TestTokenCache_TokenRespectsContextTimeout drives the OAuth token
// endpoint with an httptest server that responds slowly, then calls
// Token with a 100ms context. The call must return within ~2s (the
// server's slow-respond cap) — pre-fix it would wait up to 30s (the
// http.Client's default Timeout).
func TestTokenCache_TokenRespectsContextTimeout(t *testing.T) {
slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 2s is enough to outlast the 100ms ctx deadline + reasonable
// slack but short enough that the test process won't hang if
// cancellation breaks.
select {
case <-time.After(2 * time.Second):
case <-r.Context().Done():
}
}))
defer slow.Close()

prev := SetOAuthTokenURL(slow.URL)
defer SetOAuthTokenURL(prev)

tc := NewTokenCache("client-id", "client-secret")

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

start := time.Now()
_, err := tc.Token(ctx)
elapsed := time.Since(start)

if err == nil {
t.Fatal("expected error from cancelled context, got nil")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("err chain should contain context.DeadlineExceeded so errors.Is callers can handle it; got: %v", err)
}
if elapsed > 2*time.Second {
t.Errorf("Token returned in %v, want fast cancel; pre-fix this would wait up to 30s (server slow-respond was 2s)", elapsed)
}
}

// TestTokenCache_TokenRespectsContextCancellation tests the
// Ctrl-C / parent-cancel path. Identical structure to the timeout
// case but the cancellation is explicit.
func TestTokenCache_TokenRespectsContextCancellation(t *testing.T) {
slow := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
select {
case <-time.After(2 * time.Second):
case <-r.Context().Done():
}
}))
defer slow.Close()
hung := slow // keep variable name for body below

prev := SetOAuthTokenURL(hung.URL)
defer SetOAuthTokenURL(prev)

tc := NewTokenCache("client-id", "client-secret")

ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(50 * time.Millisecond)
cancel()
}()

_, err := tc.Token(ctx)
if !errors.Is(err, context.Canceled) {
t.Errorf("err chain should contain context.Canceled; got: %v", err)
}
}

// TestRetryTransport_BackoffRespectsContext verifies the retry loop's
// backoff sleep no longer blocks past a context deadline. Without the
// fix, a 1s retry backoff would run to completion even with a 50ms
// context — making `jc doctor --probe-timeout 100ms` regularly take
// seconds.
func TestRetryTransport_BackoffRespectsContext(t *testing.T) {
var attemptCount atomic.Int32
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
attemptCount.Add(1)
w.WriteHeader(http.StatusServiceUnavailable) // 503 → retryable
}))
defer ts.Close()

rt := newRetryTransport(http.DefaultTransport)
client := &http.Client{Transport: rt}

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()

req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL, nil)

start := time.Now()
resp, err := client.Do(req)
elapsed := time.Since(start)
if resp != nil {
resp.Body.Close()
}

if err == nil {
t.Fatal("expected error when context cancels during backoff, got nil response")
}
// Should fail before the 1s backoff that follows the first retry.
if elapsed > 1*time.Second {
t.Errorf("retry loop ran for %v after context expired; expected ctx-cancel during backoff to short-circuit", elapsed)
}
}

// TestRetryTransport_PreservesContextErrorInChain pins that the
// retry transport's error return is errors.Is-friendly for
// context.DeadlineExceeded. doctor's classifyProbeError depends on
// this for its "timeout" classification.
func TestRetryTransport_PreservesContextErrorInChain(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusServiceUnavailable)
}))
defer ts.Close()

rt := newRetryTransport(http.DefaultTransport)
client := &http.Client{Transport: rt}

ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancel()
time.Sleep(20 * time.Millisecond) // ensure ctx fires before the request

req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL, nil)
_, err := client.Do(req)
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("err chain should contain context.DeadlineExceeded; got: %v", err)
}
}
29 changes: 23 additions & 6 deletions internal/api/oauth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -58,8 +59,19 @@ func NewTokenCache(clientID, clientSecret string) *TokenCache {
}
}

// Token returns a valid bearer token, refreshing if expired or not yet fetched.
func (tc *TokenCache) Token() (string, error) {
// Token returns a valid bearer token, refreshing if expired or not yet
// fetched. The context is honored during a fetch — callers with a tight
// deadline (e.g. `jc doctor --probe-timeout 100ms`) get a clean
// context-error return instead of waiting through the http.Client's
// 30s default timeout. KLA-448 closed the context-leak that forced
// jc doctor to wrap its probe in a goroutine.
func (tc *TokenCache) Token(ctx context.Context) (string, error) {
if ctx == nil {
// Defensive — http.NewRequestWithContext panics on nil. Cobra
// invariants give cmd.Context() != nil in production, but tests
// that construct commands without RunE can leave it nil.
ctx = context.Background()
}
tc.mu.Lock()
defer tc.mu.Unlock()

Expand All @@ -69,7 +81,7 @@ func (tc *TokenCache) Token() (string, error) {
}

// Fetch a new token.
token, expiresIn, err := tc.fetchToken()
token, expiresIn, err := tc.fetchToken(ctx)
if err != nil {
return "", err
}
Expand All @@ -86,13 +98,18 @@ func (tc *TokenCache) ExpiresAt() time.Time {
return tc.expiresAt
}

// fetchToken exchanges client credentials for a bearer token.
func (tc *TokenCache) fetchToken() (string, int, error) {
// fetchToken exchanges client credentials for a bearer token. The
// context is propagated to the outbound HTTP request so a caller's
// deadline / cancellation reaches the actual socket — pre-KLA-448
// this used http.NewRequest (no context) and the request would run
// to the http.Client's 30s timeout regardless of what the caller
// asked for.
func (tc *TokenCache) fetchToken(ctx context.Context) (string, int, error) {
data := url.Values{}
data.Set("grant_type", "client_credentials")
data.Set("scope", "api")

req, err := http.NewRequest("POST", oauthTokenURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(ctx, "POST", oauthTokenURL, strings.NewReader(data.Encode()))
if err != nil {
return "", 0, fmt.Errorf("failed to create token request: %w", err)
}
Expand Down
23 changes: 12 additions & 11 deletions internal/api/oauth_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package api

import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -52,7 +53,7 @@ func TestTokenCache_FetchToken_Success(t *testing.T) {
defer func() { oauthTokenURL = orig }()

tc := NewTokenCache("test-client-id", "test-client-secret")
token, err := tc.Token()
token, err := tc.Token(context.Background())
if err != nil {
t.Fatalf("Token() error: %v", err)
}
Expand Down Expand Up @@ -81,13 +82,13 @@ func TestTokenCache_CachesToken(t *testing.T) {
tc := NewTokenCache("client-id", "client-secret")

// First call should fetch.
_, err := tc.Token()
_, err := tc.Token(context.Background())
if err != nil {
t.Fatalf("first Token() error: %v", err)
}

// Second call should use cache.
token, err := tc.Token()
token, err := tc.Token(context.Background())
if err != nil {
t.Fatalf("second Token() error: %v", err)
}
Expand Down Expand Up @@ -125,7 +126,7 @@ func TestTokenCache_RefreshesExpiredToken(t *testing.T) {
tc := NewTokenCache("client-id", "client-secret")

// First call.
_, _ = tc.Token()
_, _ = tc.Token(context.Background())
if callCount != 1 {
t.Fatalf("expected 1 call after first Token(), got %d", callCount)
}
Expand All @@ -134,7 +135,7 @@ func TestTokenCache_RefreshesExpiredToken(t *testing.T) {
fakeNow = fakeNow.Add(3631 * time.Second)

// Should refresh.
_, err := tc.Token()
_, err := tc.Token(context.Background())
if err != nil {
t.Fatalf("Token() error after expiry: %v", err)
}
Expand All @@ -155,7 +156,7 @@ func TestTokenCache_InvalidCredentials(t *testing.T) {
defer func() { oauthTokenURL = orig }()

tc := NewTokenCache("bad-id", "bad-secret")
_, err := tc.Token()
_, err := tc.Token(context.Background())
if err == nil {
t.Fatal("expected error for invalid credentials")
}
Expand All @@ -176,7 +177,7 @@ func TestTokenCache_ForbiddenCredentials(t *testing.T) {
defer func() { oauthTokenURL = orig }()

tc := NewTokenCache("client-id", "client-secret")
_, err := tc.Token()
_, err := tc.Token(context.Background())
if err == nil {
t.Fatal("expected error for forbidden credentials")
}
Expand All @@ -197,7 +198,7 @@ func TestTokenCache_ServerError(t *testing.T) {
defer func() { oauthTokenURL = orig }()

tc := NewTokenCache("client-id", "client-secret")
_, err := tc.Token()
_, err := tc.Token(context.Background())
if err == nil {
t.Fatal("expected error for server error")
}
Expand All @@ -222,7 +223,7 @@ func TestTokenCache_EmptyAccessToken(t *testing.T) {
defer func() { oauthTokenURL = orig }()

tc := NewTokenCache("client-id", "client-secret")
_, err := tc.Token()
_, err := tc.Token(context.Background())
if err == nil {
t.Fatal("expected error for empty access token")
}
Expand Down Expand Up @@ -252,7 +253,7 @@ func TestTokenCache_DefaultExpiresIn(t *testing.T) {
defer func() { nowFunc = origNow }()

tc := NewTokenCache("client-id", "client-secret")
_, err := tc.Token()
_, err := tc.Token(context.Background())
if err != nil {
t.Fatalf("Token() error: %v", err)
}
Expand All @@ -278,7 +279,7 @@ func TestTokenCache_ConnectionError(t *testing.T) {
defer func() { oauthTokenURL = orig }()

tc := NewTokenCache("client-id", "client-secret")
_, err := tc.Token()
_, err := tc.Token(context.Background())
if err == nil {
t.Fatal("expected error for connection failure")
}
Expand Down
Loading
Loading