diff --git a/pkg/authz/response_filter.go b/pkg/authz/response_filter.go index e00dceadf5..02ff1d76d9 100644 --- a/pkg/authz/response_filter.go +++ b/pkg/authz/response_filter.go @@ -187,35 +187,44 @@ func (rfw *ResponseFilteringWriter) processSSEResponse(rawResponse []byte) error var written bool if data, ok := bytes.CutPrefix(line, []byte("data:")); ok { message, err := jsonrpc2.DecodeMessage(data) - if err != nil { - rfw.ResponseWriter.WriteHeader(rfw.statusCode) - _, err := rfw.ResponseWriter.Write(rawResponse) - return err - } - - response, ok := message.(*jsonrpc2.Response) - if !ok { - rfw.ResponseWriter.WriteHeader(rfw.statusCode) - _, err := rfw.ResponseWriter.Write(rawResponse) - return err - } - - filteredResponse, err := rfw.filterListResponse(response) - if err != nil { - return rfw.writeErrorResponse(response.ID, err) - } - - filteredData, err := jsonrpc2.EncodeMessage(filteredResponse) - if err != nil { - return rfw.writeErrorResponse(response.ID, err) + switch { + case err != nil: + // Pass this line through unfiltered. Earlier revisions wrote + // rawResponse and returned here, which leaked every subsequent + // data line on the stream past the filter (issue #5257). The + // WARN fires for every filtered method (tools/list, + // prompts/list, resources/list, find_tool) because the bypass + // applies equally to all of them. + slog.Warn("SSE data line could not be decoded as JSON-RPC; passing through unfiltered", + "method", rfw.method, "error", err) + default: + if response, ok := message.(*jsonrpc2.Response); ok { + filteredResponse, err := rfw.filterListResponse(response) + if err != nil { + return rfw.writeErrorResponse(response.ID, err) + } + + filteredData, err := jsonrpc2.EncodeMessage(filteredResponse) + if err != nil { + return rfw.writeErrorResponse(response.ID, err) + } + + _, err = rfw.ResponseWriter.Write([]byte("data: " + string(filteredData) + "\n")) + if err != nil { + return fmt.Errorf("%w: %w", errBug, err) + } + + written = true + } else { + // Non-Response message (e.g. a notifications/* frame + // interleaved on the stream). Pass through unfiltered for + // this line only; the next data line may still be the real + // response and must reach the filter. Logs at WARN for + // every filtered method, not just tools/list. + slog.Warn("SSE data line was not a JSON-RPC Response; passing through unfiltered", + "method", rfw.method) + } } - - _, err = rfw.ResponseWriter.Write([]byte("data: " + string(filteredData) + "\n")) - if err != nil { - return fmt.Errorf("%w: %w", errBug, err) - } - - written = true } if !written { diff --git a/pkg/authz/response_filter_test.go b/pkg/authz/response_filter_test.go index f36c70e772..d70986ddb8 100644 --- a/pkg/authz/response_filter_test.go +++ b/pkg/authz/response_filter_test.go @@ -863,3 +863,200 @@ func TestOptimizerPassThroughToolsInResponseFilter(t *testing.T) { "admin_tool has no permit policy and is not a pass-through tool") }) } + +// TestResponseFilteringWriter_SSE_PerLineFallthrough is a regression test for +// issue #5257: when an SSE upstream interleaves a non-Response data line (e.g. +// an MCP notification) or an undecodable data line with a real list response, +// the filter previously wrote the entire raw upstream payload and returned, +// leaking the unfiltered list past Cedar. It must instead pass only the +// offending line through and continue filtering the rest of the stream. +// +// The same code path runs for every method covered by +// requiresResponseFiltering, so each of tools/list, prompts/list, and +// resources/list is exercised below. +func TestResponseFilteringWriter_SSE_PerLineFallthrough(t *testing.T) { + t.Parallel() + + authorizer, err := cedar.NewCedarAuthorizer(cedar.ConfigOptions{ + Policies: []string{ + `permit(principal, action == Action::"call_tool", resource == Tool::"weather");`, + `permit(principal, action == Action::"get_prompt", resource == Prompt::"greeting");`, + `permit(principal, action == Action::"read_resource", resource == Resource::"data");`, + }, + EntitiesJSON: `[]`, + }, "") + require.NoError(t, err) + + identity := &auth.Identity{PrincipalInfo: auth.PrincipalInfo{ + Subject: "user1", + Claims: map[string]interface{}{"sub": "user1"}, + }} + + // encodeListResponse marshals a list result type into a JSON-RPC Response + // data line. + encodeListResponse := func(t *testing.T, result interface{}) string { + t.Helper() + resultJSON, err := json.Marshal(result) + require.NoError(t, err) + encoded, err := jsonrpc2.EncodeMessage(&jsonrpc2.Response{ + ID: jsonrpc2.Int64ID(1), + Result: json.RawMessage(resultJSON), + }) + require.NoError(t, err) + return "data: " + string(encoded) + } + + // methodCase describes how to build a filterable response for one MCP + // list method and how to read the filtered names out of the wire output. + type methodCase struct { + name string + method string + respLine string + authorizedName string + unauthorizedName string + extractNames func(t *testing.T, result json.RawMessage) []string + } + + methodCases := []methodCase{ + { + name: "tools/list", + method: string(mcp.MethodToolsList), + respLine: encodeListResponse(t, mcp.ListToolsResult{ + Tools: []mcp.Tool{ + {Name: "weather", Description: "Get weather information"}, + {Name: "admin_tool", Description: "Sensitive admin operations"}, + }, + }), + authorizedName: "weather", + unauthorizedName: "admin_tool", + extractNames: func(t *testing.T, result json.RawMessage) []string { + t.Helper() + var r mcp.ListToolsResult + require.NoError(t, json.Unmarshal(result, &r)) + names := make([]string, len(r.Tools)) + for i, tool := range r.Tools { + names[i] = tool.Name + } + return names + }, + }, + { + name: "prompts/list", + method: string(mcp.MethodPromptsList), + respLine: encodeListResponse(t, mcp.ListPromptsResult{ + Prompts: []mcp.Prompt{ + {Name: "greeting", Description: "Generate greetings"}, + {Name: "admin_prompt", Description: "Sensitive admin prompt"}, + }, + }), + authorizedName: "greeting", + unauthorizedName: "admin_prompt", + extractNames: func(t *testing.T, result json.RawMessage) []string { + t.Helper() + var r mcp.ListPromptsResult + require.NoError(t, json.Unmarshal(result, &r)) + names := make([]string, len(r.Prompts)) + for i, p := range r.Prompts { + names[i] = p.Name + } + return names + }, + }, + { + name: "resources/list", + method: string(mcp.MethodResourcesList), + respLine: encodeListResponse(t, mcp.ListResourcesResult{ + Resources: []mcp.Resource{ + {URI: "data", Name: "Data Resource"}, + {URI: "secret", Name: "Sensitive Resource"}, + }, + }), + authorizedName: "data", + unauthorizedName: "secret", + extractNames: func(t *testing.T, result json.RawMessage) []string { + t.Helper() + var r mcp.ListResourcesResult + require.NoError(t, json.Unmarshal(result, &r)) + names := make([]string, len(r.Resources)) + for i, res := range r.Resources { + names[i] = res.URI + } + return names + }, + }, + } + + precedingLineCases := []struct { + name string + line string + }{ + { + name: "non-response data line", + // A notifications/* frame is a valid JSON-RPC notification + // (no id), so jsonrpc2.DecodeMessage returns a non-Response + // message. The buggy path treated this as a signal to dump + // rawResponse and return. + line: `data: {"jsonrpc":"2.0","method":"notifications/message","params":{"level":"info","data":"warming up"}}`, + }, + { + name: "undecodable data line", + line: `data: this is not json at all`, + }, + } + + for _, mc := range methodCases { + for _, plc := range precedingLineCases { + mc, plc := mc, plc + t.Run(mc.name+"/"+plc.name, func(t *testing.T) { + t.Parallel() + + req, err := http.NewRequest(http.MethodPost, "/messages", nil) + require.NoError(t, err) + req = req.WithContext(auth.WithIdentity(req.Context(), identity)) + + rr := httptest.NewRecorder() + rfw := NewResponseFilteringWriter(rr, authorizer, req, mc.method, nil, nil) + rfw.ResponseWriter.Header().Set("Content-Type", "text/event-stream") + + body := strings.Join([]string{plc.line, mc.respLine, ""}, "\n") + _, err = rfw.Write([]byte(body)) + require.NoError(t, err) + + require.NoError(t, rfw.FlushAndFilter()) + + out := rr.Body.String() + + // The preceding line must still appear verbatim; pass-through + // is the whole point of the fix. + assert.Contains(t, out, plc.line, + "non-response/undecodable preceding line must pass through unchanged") + + // The real list response must have been filtered. Pull the + // last JSON-RPC Response data line out and decode it. + var filteredLine string + for _, line := range strings.Split(out, "\n") { + if strings.HasPrefix(line, "data: {\"jsonrpc\"") && strings.Contains(line, `"result"`) { + filteredLine = line + } + } + require.NotEmpty(t, filteredLine, "no JSON-RPC Response data line found in output") + + payload := strings.TrimPrefix(filteredLine, "data: ") + msg, err := jsonrpc2.DecodeMessage([]byte(payload)) + require.NoError(t, err) + resp, ok := msg.(*jsonrpc2.Response) + require.True(t, ok) + + names := mc.extractNames(t, resp.Result) + assert.Contains(t, names, mc.authorizedName, "authorized entry must be retained") + assert.NotContains(t, names, mc.unauthorizedName, + "unauthorized entry must be filtered; presence indicates the cedar bypass from #5257 is back") + + // And the raw unfiltered payload (the bug used to dump it) + // must not appear in the wire output. + assert.NotContains(t, out, `"`+mc.unauthorizedName+`"`, + "unfiltered list payload leaked into SSE output") + }) + } + } +}