diff --git a/internal/plugins/backends/ollama/plugin.go b/internal/plugins/backends/ollama/plugin.go index 6633c02b..585aebb3 100644 --- a/internal/plugins/backends/ollama/plugin.go +++ b/internal/plugins/backends/ollama/plugin.go @@ -138,7 +138,7 @@ func fetchVersion(ctx context.Context, client *http.Client, nativeRoot string) ( if resp.StatusCode < 200 || resp.StatusCode > 299 { return "", fmt.Errorf("version HTTP status %d", resp.StatusCode) } - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) if err != nil { return "", err } diff --git a/internal/plugins/backends/ollama/plugin_test.go b/internal/plugins/backends/ollama/plugin_test.go index 731ad8ae..eba4fd77 100644 --- a/internal/plugins/backends/ollama/plugin_test.go +++ b/internal/plugins/backends/ollama/plugin_test.go @@ -138,3 +138,22 @@ func TestNew_VersionDetectionAndResponses(t *testing.T) { t.Fatal("expected explicitly disabled responses to not be supported") } } + +func TestFetchVersion_OOM_Regression(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + chunk := make([]byte, 1024*1024) + for i := 0; i < 50; i++ { + w.Write(chunk) + } + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := fetchVersion(ctx, srv.Client(), srv.URL) + if err == nil { + t.Fatalf("expected error due to large payload or invalid json, got none") + } +} diff --git a/internal/plugins/backends/opencodecommon/catalog/loader.go b/internal/plugins/backends/opencodecommon/catalog/loader.go index aff49042..6c8b9ee8 100644 --- a/internal/plugins/backends/opencodecommon/catalog/loader.go +++ b/internal/plugins/backends/opencodecommon/catalog/loader.go @@ -132,7 +132,7 @@ func getJSON(ctx context.Context, client *http.Client, endpoint string, headers body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096)) return nil, fmt.Errorf("opencodecommon: model discovery HTTP status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) } - body, err := io.ReadAll(resp.Body) + body, err := io.ReadAll(io.LimitReader(resp.Body, 10*1024*1024)) if err != nil { return nil, fmt.Errorf("opencodecommon: model discovery read: %w", err) } diff --git a/internal/plugins/backends/opencodecommon/catalog/loader_test.go b/internal/plugins/backends/opencodecommon/catalog/loader_test.go index 00e94bed..6850cdde 100644 --- a/internal/plugins/backends/opencodecommon/catalog/loader_test.go +++ b/internal/plugins/backends/opencodecommon/catalog/loader_test.go @@ -115,3 +115,22 @@ func TestLoadModelEntries_noFallbackOnRemoteFailure(t *testing.T) { t.Fatal("expected error without fallback") } } + +func TestGetJSON_OOM_Regression(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + chunk := make([]byte, 1024*1024) + for i := 0; i < 50; i++ { + w.Write(chunk) + } + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err := getJSON(ctx, srv.Client(), srv.URL, nil) + if err == nil { + t.Fatalf("expected error due to large payload or invalid json, got none") + } +}