Skip to content

Commit ddd4bd7

Browse files
appleboyclaude
andauthored
fix: bound HTTP response reads and use sync.Once for config init (#16)
* fix(main): bound HTTP response reads and use sync.Once for config init - Limit all HTTP response body reads to 1 MB using io.LimitReader to prevent memory exhaustion - Replace non-thread-safe configInitialized bool with sync.Once for safe concurrent initialization Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(http): detect oversized responses instead of silent truncation - Add readResponseBody helper with explicit size limit detection - Replace 5 inline io.LimitReader calls with the shared helper - Return errResponseTooLarge for responses exceeding 1MB - Add unit tests for boundary, oversized, small, and empty responses - Add end-to-end test for oversized response in requestDeviceCode Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent b1def4f commit ddd4bd7

2 files changed

Lines changed: 120 additions & 20 deletions

File tree

main.go

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,40 @@ import (
2828
)
2929

3030
var (
31-
serverURL string
32-
clientID string
33-
tokenFile string
34-
tokenStoreMode string
35-
flagServerURL *string
36-
flagClientID *string
37-
flagTokenFile *string
38-
flagTokenStore *string
39-
configInitialized bool
40-
retryClient *retry.Client
41-
tokenStore credstore.Store[credstore.Token]
31+
serverURL string
32+
clientID string
33+
tokenFile string
34+
tokenStoreMode string
35+
flagServerURL *string
36+
flagClientID *string
37+
flagTokenFile *string
38+
flagTokenStore *string
39+
configOnce sync.Once
40+
retryClient *retry.Client
41+
tokenStore credstore.Store[credstore.Token]
4242
)
4343

4444
const defaultKeyringService = "authgate-device-cli"
4545

46+
// maxResponseBodySize limits HTTP response body reads to prevent memory exhaustion (DoS).
47+
const maxResponseBodySize = 1 << 20 // 1 MB
48+
49+
// errResponseTooLarge indicates the server returned an oversized response body.
50+
var errResponseTooLarge = errors.New("response body exceeds maximum allowed size")
51+
52+
// readResponseBody reads the response body up to maxResponseBodySize.
53+
// Returns errResponseTooLarge if the body exceeds the limit.
54+
func readResponseBody(body io.Reader) ([]byte, error) {
55+
data, err := io.ReadAll(io.LimitReader(body, maxResponseBodySize+1))
56+
if err != nil {
57+
return nil, fmt.Errorf("failed to read response: %w", err)
58+
}
59+
if int64(len(data)) > maxResponseBodySize {
60+
return nil, errResponseTooLarge
61+
}
62+
return data, nil
63+
}
64+
4665
// Timeout configuration for different operations
4766
const (
4867
deviceCodeRequestTimeout = 10 * time.Second
@@ -107,11 +126,12 @@ func init() {
107126
// initConfig parses flags and initializes configuration
108127
// Separated from init() to avoid conflicts with test flag parsing
109128
func initConfig() {
110-
if configInitialized {
111-
return
112-
}
113-
configInitialized = true
129+
configOnce.Do(func() {
130+
doInitConfig()
131+
})
132+
}
114133

134+
func doInitConfig() {
115135
flag.Parse()
116136

117137
// Priority: flag > env > default
@@ -438,7 +458,7 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
438458
}
439459
defer resp.Body.Close()
440460

441-
body, err := io.ReadAll(resp.Body)
461+
body, err := readResponseBody(resp.Body)
442462
if err != nil {
443463
return nil, fmt.Errorf("failed to read response: %w", err)
444464
}
@@ -633,7 +653,7 @@ func exchangeDeviceCode(
633653
}
634654
defer resp.Body.Close()
635655

636-
body, err := io.ReadAll(resp.Body)
656+
body, err := readResponseBody(resp.Body)
637657
if err != nil {
638658
return nil, fmt.Errorf("failed to read response: %w", err)
639659
}
@@ -691,7 +711,7 @@ func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error
691711
}
692712
defer resp.Body.Close()
693713

694-
body, err := io.ReadAll(resp.Body)
714+
body, err := readResponseBody(resp.Body)
695715
if err != nil {
696716
return fmt.Errorf("failed to read response: %w", err)
697717
}
@@ -741,7 +761,7 @@ func refreshAccessToken(
741761
}
742762
defer resp.Body.Close()
743763

744-
body, err := io.ReadAll(resp.Body)
764+
body, err := readResponseBody(resp.Body)
745765
if err != nil {
746766
return credstore.Token{}, fmt.Errorf("failed to read response: %w", err)
747767
}
@@ -869,7 +889,7 @@ func makeAPICallWithAutoRefresh(
869889
defer resp.Body.Close()
870890
}
871891

872-
body, err := io.ReadAll(resp.Body)
892+
body, err := readResponseBody(resp.Body)
873893
if err != nil {
874894
return fmt.Errorf("failed to read response: %w", err)
875895
}

main_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package main
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
7+
"errors"
68
"fmt"
79
"net/http"
810
"net/http/httptest"
@@ -579,3 +581,81 @@ func TestRequestDeviceCode_WithRetry(t *testing.T) {
579581
t.Errorf("Expected 2 attempts (1 retry), got %d", finalCount)
580582
}
581583
}
584+
585+
func TestReadResponseBody_ExactlyAtLimit(t *testing.T) {
586+
data := make([]byte, maxResponseBodySize)
587+
body, err := readResponseBody(bytes.NewReader(data))
588+
if err != nil {
589+
t.Fatalf("unexpected error: %v", err)
590+
}
591+
if len(body) != int(maxResponseBodySize) {
592+
t.Errorf("expected %d bytes, got %d", maxResponseBodySize, len(body))
593+
}
594+
}
595+
596+
func TestReadResponseBody_ExceedsLimit(t *testing.T) {
597+
data := make([]byte, maxResponseBodySize+1)
598+
_, err := readResponseBody(bytes.NewReader(data))
599+
if !errors.Is(err, errResponseTooLarge) {
600+
t.Errorf("expected errResponseTooLarge, got %v", err)
601+
}
602+
}
603+
604+
func TestReadResponseBody_SmallBody(t *testing.T) {
605+
expected := "hello world"
606+
body, err := readResponseBody(strings.NewReader(expected))
607+
if err != nil {
608+
t.Fatalf("unexpected error: %v", err)
609+
}
610+
if string(body) != expected {
611+
t.Errorf("expected %q, got %q", expected, string(body))
612+
}
613+
}
614+
615+
func TestReadResponseBody_EmptyBody(t *testing.T) {
616+
body, err := readResponseBody(strings.NewReader(""))
617+
if err != nil {
618+
t.Fatalf("unexpected error: %v", err)
619+
}
620+
if len(body) != 0 {
621+
t.Errorf("expected empty body, got %d bytes", len(body))
622+
}
623+
}
624+
625+
func TestRequestDeviceCode_OversizedResponse(t *testing.T) {
626+
// Server that returns a response larger than maxResponseBodySize
627+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
628+
w.Header().Set("Content-Type", "application/json")
629+
w.WriteHeader(http.StatusOK)
630+
// Write more than maxResponseBodySize
631+
data := make([]byte, maxResponseBodySize+100)
632+
for i := range data {
633+
data[i] = 'a'
634+
}
635+
_, _ = w.Write(data)
636+
}))
637+
defer server.Close()
638+
639+
oldServerURL := serverURL
640+
serverURL = server.URL
641+
defer func() { serverURL = oldServerURL }()
642+
643+
oldClient := retryClient
644+
newClient, err := retry.NewBackgroundClient(
645+
retry.WithHTTPClient(server.Client()),
646+
)
647+
if err != nil {
648+
t.Fatalf("failed to create retry client: %v", err)
649+
}
650+
retryClient = newClient
651+
defer func() { retryClient = oldClient }()
652+
653+
ctx := context.Background()
654+
_, err = requestDeviceCode(ctx)
655+
if err == nil {
656+
t.Fatal("expected error for oversized response, got nil")
657+
}
658+
if !errors.Is(err, errResponseTooLarge) {
659+
t.Errorf("expected errResponseTooLarge in error chain, got: %v", err)
660+
}
661+
}

0 commit comments

Comments
 (0)