Skip to content

Commit b0c7632

Browse files
appleboyclaude
andcommitted
fix(http): detect oversized responses instead of silent truncation
- Add readResponseBody helper with explicit size limit detection - Replace 5 inline io.LimitReader calls with the shared helper - Return errResponseTooLarge for responses exceeding 1MB - Add unit tests for boundary, oversized, small, and empty responses - Add end-to-end test for oversized response in requestDeviceCode Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent dbdd3a0 commit b0c7632

2 files changed

Lines changed: 102 additions & 5 deletions

File tree

main.go

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,22 @@ const defaultKeyringService = "authgate-device-cli"
4646
// maxResponseBodySize limits HTTP response body reads to prevent memory exhaustion (DoS).
4747
const maxResponseBodySize = 1 << 20 // 1 MB
4848

49+
// errResponseTooLarge indicates the server returned an oversized response body.
50+
var errResponseTooLarge = errors.New("response body exceeds maximum allowed size")
51+
52+
// readResponseBody reads the response body up to maxResponseBodySize.
53+
// Returns errResponseTooLarge if the body exceeds the limit.
54+
func readResponseBody(body io.Reader) ([]byte, error) {
55+
data, err := io.ReadAll(io.LimitReader(body, maxResponseBodySize+1))
56+
if err != nil {
57+
return nil, fmt.Errorf("failed to read response: %w", err)
58+
}
59+
if int64(len(data)) > maxResponseBodySize {
60+
return nil, errResponseTooLarge
61+
}
62+
return data, nil
63+
}
64+
4965
// Timeout configuration for different operations
5066
const (
5167
deviceCodeRequestTimeout = 10 * time.Second
@@ -442,7 +458,7 @@ func requestDeviceCode(ctx context.Context) (*oauth2.DeviceAuthResponse, error)
442458
}
443459
defer resp.Body.Close()
444460

445-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
461+
body, err := readResponseBody(resp.Body)
446462
if err != nil {
447463
return nil, fmt.Errorf("failed to read response: %w", err)
448464
}
@@ -642,7 +658,7 @@ func exchangeDeviceCode(
642658
}
643659
defer resp.Body.Close()
644660

645-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
661+
body, err := readResponseBody(resp.Body)
646662
if err != nil {
647663
return nil, fmt.Errorf("failed to read response: %w", err)
648664
}
@@ -700,7 +716,7 @@ func verifyToken(ctx context.Context, accessToken string, d tui.Displayer) error
700716
}
701717
defer resp.Body.Close()
702718

703-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
719+
body, err := readResponseBody(resp.Body)
704720
if err != nil {
705721
return fmt.Errorf("failed to read response: %w", err)
706722
}
@@ -750,7 +766,7 @@ func refreshAccessToken(
750766
}
751767
defer resp.Body.Close()
752768

753-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
769+
body, err := readResponseBody(resp.Body)
754770
if err != nil {
755771
return credstore.Token{}, fmt.Errorf("failed to read response: %w", err)
756772
}
@@ -875,7 +891,7 @@ func makeAPICallWithAutoRefresh(
875891
defer resp.Body.Close()
876892
}
877893

878-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBodySize))
894+
body, err := readResponseBody(resp.Body)
879895
if err != nil {
880896
return fmt.Errorf("failed to read response: %w", err)
881897
}

main_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package main
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
7+
"errors"
68
"fmt"
79
"net/http"
810
"net/http/httptest"
911
"os"
1012
"path/filepath"
13+
"strings"
1114
"sync"
1215
"sync/atomic"
1316
"testing"
@@ -593,3 +596,81 @@ func TestRequestDeviceCode_WithRetry(t *testing.T) {
593596
t.Errorf("Expected 2 attempts (1 retry), got %d", finalCount)
594597
}
595598
}
599+
600+
func TestReadResponseBody_ExactlyAtLimit(t *testing.T) {
601+
data := make([]byte, maxResponseBodySize)
602+
body, err := readResponseBody(bytes.NewReader(data))
603+
if err != nil {
604+
t.Fatalf("unexpected error: %v", err)
605+
}
606+
if len(body) != int(maxResponseBodySize) {
607+
t.Errorf("expected %d bytes, got %d", maxResponseBodySize, len(body))
608+
}
609+
}
610+
611+
func TestReadResponseBody_ExceedsLimit(t *testing.T) {
612+
data := make([]byte, maxResponseBodySize+1)
613+
_, err := readResponseBody(bytes.NewReader(data))
614+
if !errors.Is(err, errResponseTooLarge) {
615+
t.Errorf("expected errResponseTooLarge, got %v", err)
616+
}
617+
}
618+
619+
func TestReadResponseBody_SmallBody(t *testing.T) {
620+
expected := "hello world"
621+
body, err := readResponseBody(strings.NewReader(expected))
622+
if err != nil {
623+
t.Fatalf("unexpected error: %v", err)
624+
}
625+
if string(body) != expected {
626+
t.Errorf("expected %q, got %q", expected, string(body))
627+
}
628+
}
629+
630+
func TestReadResponseBody_EmptyBody(t *testing.T) {
631+
body, err := readResponseBody(strings.NewReader(""))
632+
if err != nil {
633+
t.Fatalf("unexpected error: %v", err)
634+
}
635+
if len(body) != 0 {
636+
t.Errorf("expected empty body, got %d bytes", len(body))
637+
}
638+
}
639+
640+
func TestRequestDeviceCode_OversizedResponse(t *testing.T) {
641+
// Server that returns a response larger than maxResponseBodySize
642+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
643+
w.Header().Set("Content-Type", "application/json")
644+
w.WriteHeader(http.StatusOK)
645+
// Write more than maxResponseBodySize
646+
data := make([]byte, maxResponseBodySize+100)
647+
for i := range data {
648+
data[i] = 'a'
649+
}
650+
_, _ = w.Write(data)
651+
}))
652+
defer server.Close()
653+
654+
oldServerURL := serverURL
655+
serverURL = server.URL
656+
defer func() { serverURL = oldServerURL }()
657+
658+
oldClient := retryClient
659+
newClient, err := retry.NewBackgroundClient(
660+
retry.WithHTTPClient(server.Client()),
661+
)
662+
if err != nil {
663+
t.Fatalf("failed to create retry client: %v", err)
664+
}
665+
retryClient = newClient
666+
defer func() { retryClient = oldClient }()
667+
668+
ctx := context.Background()
669+
_, err = requestDeviceCode(ctx)
670+
if err == nil {
671+
t.Fatal("expected error for oversized response, got nil")
672+
}
673+
if !errors.Is(err, errResponseTooLarge) {
674+
t.Errorf("expected errResponseTooLarge in error chain, got: %v", err)
675+
}
676+
}

0 commit comments

Comments
 (0)