From af548174838aa4d7b377421d69d8bcd09e8ec097 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Fri, 1 May 2026 12:36:05 +0000 Subject: [PATCH 01/15] test(proxy): integration tests for session correlation audit and header agreement Add integration tests that verify the core invariants of session correlation across the proxy, auditor, and forwarded request headers working together. These tests fill the gap identified during review of the session correlation PR stack (#196, #197, #198) where unit tests verified each component in isolation but did not verify them in concert. New test file: proxy/proxy_session_correlation_integration_test.go Tests added: - LLMRequestAuditAndHeadersAgree: audit sequence number matches the forwarded header value on inject-target requests. - NonLLMRequestAuditedWithoutHeaders: allowed non-inject-target requests are audited but carry no correlation headers. - DeniedRequestAuditedNeverForwarded: denied requests consume a sequence number but are never forwarded. - MixedRequestsSequenceOrdering: interleaved LLM, non-LLM, and denied requests all advance the counter monotonically. - SequenceGapRevealsAgenticLoop: gap between two LLM sequence numbers precisely equals intermediate tool-use requests. - SpoofedHeadersOverwrittenWithCorrectSequence: client-supplied headers are replaced and the audit event still agrees. - DisabledCorrelationNoHeadersNoPreallocatedSequence: disabled correlation means no headers and no pre-allocated sequence. - ConcurrentRequestsUniqueSequenceNumbers: concurrent requests each get a unique, dense sequence number. --- ...xy_session_correlation_integration_test.go | 578 ++++++++++++++++++ 1 file changed, 578 insertions(+) create mode 100644 proxy/proxy_session_correlation_integration_test.go diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go new file mode 100644 index 0000000..b8035b7 --- /dev/null +++ b/proxy/proxy_session_correlation_integration_test.go @@ -0,0 +1,578 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "sync" + "testing" + + "github.com/coder/boundary/audit" + "github.com/coder/boundary/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// multiRequestCapturingBackend records the headers from every request it +// receives, not just the last one. This is needed by integration tests +// that send multiple requests to the same backend and want to verify +// each one independently. +type multiRequestCapturingBackend struct { + server *httptest.Server + mu sync.Mutex + all []http.Header +} + +func newMultiRequestCapturingBackend() *multiRequestCapturingBackend { + mcb := &multiRequestCapturingBackend{} + mcb.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mcb.mu.Lock() + mcb.all = append(mcb.all, r.Header.Clone()) + mcb.mu.Unlock() + w.WriteHeader(http.StatusOK) + })) + return mcb +} + +func (m *multiRequestCapturingBackend) close() { m.server.Close() } + +func (m *multiRequestCapturingBackend) requestCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return len(m.all) +} + +func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { + m.mu.Lock() + defer m.mu.Unlock() + return m.all[i].Clone() +} + +// sessionCorrelationIntegrationSetup holds the shared objects for an +// integration test: the proxy, auditor, backend(s), and sequence +// counter. Tests build one via newSessionCorrelationIntegrationSetup +// and tear it down with stop. +type sessionCorrelationIntegrationSetup struct { + pt *ProxyTest + auditor *capturingAuditor + seq *audit.SequenceCounter + llmBackend *multiRequestCapturingBackend + otherBackend *multiRequestCapturingBackend +} + +func (s *sessionCorrelationIntegrationSetup) stop() { + s.pt.Stop() + if s.llmBackend != nil { + s.llmBackend.close() + } + if s.otherBackend != nil { + s.otherBackend.close() + } +} + +// newSessionCorrelationIntegrationSetup builds a proxy that allows +// traffic to two httptest backends: one that matches an inject target +// (simulating an LLM provider) and one that does not (simulating a +// generic allowed domain like github.com). Both backends capture all +// received request headers. A capturingAuditor records every audit +// event for later inspection. +func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sessionCorrelationIntegrationSetup { + t.Helper() + + llm := newMultiRequestCapturingBackend() + other := newMultiRequestCapturingBackend() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + // Both httptest backends resolve to 127.0.0.1, so a domain-only + // inject target would match both. We use a path glob on the LLM + // paths (/v1/*) to limit header injection to LLM requests. + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + // Allow both backends. + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + // Only requests matching the LLM path receive headers. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + + return &sessionCorrelationIntegrationSetup{ + pt: pt, + auditor: aud, + seq: seq, + llmBackend: llm, + otherBackend: other, + } +} + +// ---------- Integration Tests ---------- + +// TestIntegration_LLMRequestAuditAndHeadersAgree verifies the core +// correlation invariant: when an allowed request hits an inject target, +// the sequence number in the audit event equals the sequence number in +// the forwarded header. +func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { + const sessionID = "e5f6a7b8-0000-0000-0000-000000000000" + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Audit event. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.True(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // Forwarded headers. + require.Equal(t, 1, s.llmBackend.requestCount()) + hdr := s.llmBackend.headersAt(0) + assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // The two must agree. + assert.Equal(t, + strconv.FormatUint(*events[0].SequenceNumber, 10), + hdr.Get(config.DefaultSequenceNumberHeaderName), + "audit event and forwarded header must carry the same sequence number", + ) +} + +// TestIntegration_NonLLMRequestAuditedWithoutHeaders verifies that an +// allowed request to a domain that is NOT an inject target still gets +// audited (with a sequence number) but does NOT receive correlation +// headers. +func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { + s := newSessionCorrelationIntegrationSetup(t, "test-session") + defer s.stop() + + resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Audit event recorded. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.True(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // No correlation headers on the backend. + require.Equal(t, 1, s.otherBackend.requestCount()) + hdr := s.otherBackend.headersAt(0) + assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName), + "non-inject-target requests must not carry session ID header") + assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName), + "non-inject-target requests must not carry sequence number header") +} + +// TestIntegration_DeniedRequestAuditedNeverForwarded verifies that a +// request denied by the rules engine is audited (consuming a sequence +// number) but is never forwarded to any backend. +func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { + // Create a setup with a custom deny-all proxy, but keep the same + // pattern of shared sequence counter and auditor. + llm := newMultiRequestCapturingBackend() + defer llm.close() + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + // No allowed domains: deny everything. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("test-session"), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(llm.server.URL + "/exfil") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusForbidden, resp.StatusCode) + + // Audit event recorded. + events := aud.getRequests() + require.Len(t, events, 1) + require.False(t, events[0].Allowed) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, uint64(0), *events[0].SequenceNumber) + + // Backend never hit. + assert.Equal(t, 0, llm.requestCount(), + "denied requests must not be forwarded to the backend") +} + +// TestIntegration_MixedRequestsSequenceOrdering sends a realistic +// sequence of LLM, non-LLM, and denied requests, then verifies: +// 1. Sequence numbers increase monotonically across all request types. +// 2. Only inject-target requests carry correlation headers. +// 3. The sequence numbers in headers match the audit events. +// 4. The gap between two LLM requests' sequence numbers reveals the +// intermediate non-LLM and denied activity. +func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { + const sessionID = "mixed-test-session" + + // Two allowed backends (LLM and "github"), one denied domain. + llm := newMultiRequestCapturingBackend() + defer llm.close() + + other := newMultiRequestCapturingBackend() + defer other.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + // Only LLM is an inject target. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + // Request 0: LLM (allowed, inject target). + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 1: non-LLM (allowed, no inject). + resp, err = pt.proxyClient.Get(other.server.URL + "/coder/coder") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 2: denied (nothing is allowed for evil.example.com). + resp, err = pt.proxyClient.Get("http://evil.example.com/exfil") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusForbidden, resp.StatusCode) + + // Request 3: LLM again. + resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // -- Verify audit events -- + events := aud.getRequests() + require.Len(t, events, 4, "expected exactly four audit events") + + expectedSeq := []uint64{0, 1, 2, 3} + expectedAllowed := []bool{true, true, false, true} + for i, ev := range events { + require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must be set", i) + assert.Equal(t, expectedSeq[i], *ev.SequenceNumber, + "event %d: wrong sequence number", i) + assert.Equal(t, expectedAllowed[i], ev.Allowed, + "event %d: wrong allowed flag", i) + } + + // -- Verify LLM backend headers -- + require.Equal(t, 2, llm.requestCount(), + "LLM backend should have received exactly two requests") + + firstLLMHdr := llm.headersAt(0) + assert.Equal(t, sessionID, firstLLMHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", firstLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + "first LLM request must have sequence 0") + + secondLLMHdr := llm.headersAt(1) + assert.Equal(t, sessionID, secondLLMHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "3", secondLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + "second LLM request must have sequence 3") + + // -- Verify non-LLM backend has no correlation headers -- + require.Equal(t, 1, other.requestCount()) + otherHdr := other.headersAt(0) + assert.Empty(t, otherHdr.Get(config.DefaultSessionIDHeaderName)) + assert.Empty(t, otherHdr.Get(config.DefaultSequenceNumberHeaderName)) + + // -- Verify the gap reveals intermediate activity -- + // The gap between the two LLM sequence numbers (0 and 3) means + // that sequence numbers 1 and 2 were consumed by non-LLM + // activity, matching audit events 1 (non-LLM allowed) and 2 + // (denied). + firstLLMSeq := *events[0].SequenceNumber + secondLLMSeq := *events[3].SequenceNumber + gap := secondLLMSeq - firstLLMSeq - 1 + assert.Equal(t, uint64(2), gap, + "gap between LLM requests should reveal 2 intermediate events") +} + +// TestIntegration_SequenceGapRevealsAgenticLoop sends two LLM requests +// with several non-LLM requests in between, simulating an agentic loop +// where the model triggers tool-use HTTP calls between prompts. The +// test verifies that the gap in LLM sequence numbers precisely +// reflects the count of intermediate boundary events. +func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { + const sessionID = "agentic-loop-session" + + llm := newMultiRequestCapturingBackend() + defer llm.close() + + other := newMultiRequestCapturingBackend() + defer other.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + otherURL, err := url.Parse(other.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + seq := &audit.SequenceCounter{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(otherURL.Hostname()), + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: true, + InjectTargets: []config.InjectTarget{{ + Domain: llmURL.Hostname(), + Path: "/v1/*", + }}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID(sessionID), + WithSequenceCounter(seq), + WithAuditor(aud), + ).Start() + defer pt.Stop() + + // First LLM prompt (seq 0). + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + + // Agentic loop: three tool-use HTTP calls. + for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { + resp, err = pt.proxyClient.Get(other.server.URL + p) + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + } + + // Second LLM prompt (seq 4). + resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + resp.Body.Close() //nolint:errcheck + + // Verify LLM sequence headers. + require.Equal(t, 2, llm.requestCount()) + assert.Equal(t, "0", llm.headersAt(0).Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, "4", llm.headersAt(1).Get(config.DefaultSequenceNumberHeaderName)) + + // The gap between sequence numbers 0 and 4 is 3, matching the + // three tool-use requests in between. + events := aud.getRequests() + require.Len(t, events, 5) + + firstLLMSeq := *events[0].SequenceNumber + secondLLMSeq := *events[4].SequenceNumber + gap := secondLLMSeq - firstLLMSeq - 1 + assert.Equal(t, uint64(3), gap, + "gap between prompts should equal number of tool-use requests") + + // Verify the intermediate events are the tool-use requests. + for i := 1; i <= 3; i++ { + require.NotNil(t, events[i].SequenceNumber) + assert.Equal(t, uint64(i), *events[i].SequenceNumber) + assert.True(t, events[i].Allowed) + } +} + +// TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence +// verifies that when a jailed client sets its own correlation headers, +// the proxy replaces them with the real session ID and the real +// sequence number, and the audit event still agrees with the header. +func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) { + const sessionID = "real-session-uuid" + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + req, err := http.NewRequest(http.MethodPost, s.llmBackend.server.URL+"/v1/messages", nil) + require.NoError(t, err) + req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session") + req.Header.Set(config.DefaultSequenceNumberHeaderName, "9999") + + resp, err := s.pt.proxyClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Backend received real values, not spoofed. + require.Equal(t, 1, s.llmBackend.requestCount()) + hdr := s.llmBackend.headersAt(0) + assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // Audit event agrees with header. + events := s.auditor.getRequests() + require.Len(t, events, 1) + require.NotNil(t, events[0].SequenceNumber) + assert.Equal(t, + strconv.FormatUint(*events[0].SequenceNumber, 10), + hdr.Get(config.DefaultSequenceNumberHeaderName), + ) +} + +// TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence +// verifies that when session correlation is disabled, the proxy does +// not inject headers and does not pre-allocate sequence numbers (the +// auditor falls back to its own counter instead). +func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testing.T) { + llm := newMultiRequestCapturingBackend() + defer llm.close() + + llmURL, err := url.Parse(llm.server.URL) + require.NoError(t, err) + + aud := &capturingAuditor{} + + pt := NewProxyTest(t, + WithCertManager(t.TempDir()), + WithAllowedDomain(llmURL.Hostname()), + // Correlation disabled; no sequence counter. + WithSessionCorrelation(config.SessionCorrelationConfig{ + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, + SessionIDHeaderName: config.DefaultSessionIDHeaderName, + SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + }), + WithSessionID("should-not-appear"), + // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. + WithAuditor(aud), + ).Start() + defer pt.Stop() + + resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + require.NoError(t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) + + // No correlation headers. + require.Equal(t, 1, llm.requestCount()) + hdr := llm.headersAt(0) + assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName)) + assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName)) + + // Audit event recorded but without a pre-allocated sequence + // number (nil), because no SequenceCounter was provided. + events := aud.getRequests() + require.Len(t, events, 1) + assert.Nil(t, events[0].SequenceNumber, + "no sequence counter means no pre-allocated sequence number") +} + +// TestIntegration_ConcurrentRequestsUniqueSequenceNumbers sends +// multiple requests concurrently and verifies that every request +// receives a unique sequence number, and that the set of numbers is +// dense (no gaps, no duplicates). +func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { + const sessionID = "concurrent-session" + const numRequests = 10 + + s := newSessionCorrelationIntegrationSetup(t, sessionID) + defer s.stop() + + var wg sync.WaitGroup + for i := 0; i < numRequests; i++ { + wg.Add(1) + go func() { + defer wg.Done() + resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + assert.NoError(t, err) + if resp != nil { + resp.Body.Close() //nolint:errcheck + } + }() + } + wg.Wait() + + // Every request should have been audited. + events := s.auditor.getRequests() + require.Len(t, events, numRequests) + + // Collect all sequence numbers and verify uniqueness. + seen := make(map[uint64]bool, numRequests) + for i, ev := range events { + require.NotNil(t, ev.SequenceNumber, + "event %d: sequence number must not be nil", i) + assert.False(t, seen[*ev.SequenceNumber], + "event %d: duplicate sequence number %d", i, *ev.SequenceNumber) + seen[*ev.SequenceNumber] = true + } + + // The set should be exactly {0, 1, ..., numRequests-1}. + for i := uint64(0); i < numRequests; i++ { + assert.True(t, seen[i], + "sequence number %d is missing from the set", i) + } + + // Every header should also carry a matching sequence number. + require.Equal(t, numRequests, s.llmBackend.requestCount()) + headerSeqs := make(map[string]bool, numRequests) + for i := 0; i < numRequests; i++ { + hdr := s.llmBackend.headersAt(i) + seqStr := hdr.Get(config.DefaultSequenceNumberHeaderName) + assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) + headerSeqs[seqStr] = true + } + for i := uint64(0); i < numRequests; i++ { + assert.True(t, headerSeqs[fmt.Sprintf("%d", i)], + "header sequence number %d is missing", i) + } +} From 5bd67ece35b747664bf57d0e184134dfea19e88f Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 7 May 2026 17:27:55 +0000 Subject: [PATCH 02/15] refactor(proxy): update session correlation tests to use new header names and sequence number type Modified integration tests to reflect changes in session correlation header names and updated the sequence number type from uint64 to int32. Adjusted assertions in tests to ensure consistency with the new data types and header configurations, enhancing clarity and correctness in the test suite. --- ...xy_session_correlation_integration_test.go | 109 ++++++++---------- 1 file changed, 46 insertions(+), 63 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index b8035b7..4000270 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -108,11 +108,8 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess Domain: llmURL.Hostname(), Path: "/v1/*", }}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, }), WithSessionID(sessionID), - WithSequenceCounter(seq), WithAuditor(aud), ).Start() @@ -146,18 +143,18 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { require.Len(t, events, 1) require.True(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) - assert.Equal(t, uint64(0), *events[0].SequenceNumber) + assert.Equal(t, int32(0), events[0].SequenceNumber) // Forwarded headers. require.Equal(t, 1, s.llmBackend.requestCount()) hdr := s.llmBackend.headersAt(0) - assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) - assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, sessionID, hdr.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.SequenceNumberHeaderName)) // The two must agree. assert.Equal(t, - strconv.FormatUint(*events[0].SequenceNumber, 10), - hdr.Get(config.DefaultSequenceNumberHeaderName), + strconv.Itoa(int(events[0].SequenceNumber)), + hdr.Get(config.SequenceNumberHeaderName), "audit event and forwarded header must carry the same sequence number", ) } @@ -180,14 +177,14 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { require.Len(t, events, 1) require.True(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) - assert.Equal(t, uint64(0), *events[0].SequenceNumber) + assert.Equal(t, int32(0), events[0].SequenceNumber) // No correlation headers on the backend. require.Equal(t, 1, s.otherBackend.requestCount()) hdr := s.otherBackend.headersAt(0) - assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName), + assert.Empty(t, hdr.Get(config.SessionIDHeaderName), "non-inject-target requests must not carry session ID header") - assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName), + assert.Empty(t, hdr.Get(config.SequenceNumberHeaderName), "non-inject-target requests must not carry sequence number header") } @@ -201,19 +198,15 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { defer llm.close() aud := &capturingAuditor{} - seq := &audit.SequenceCounter{} pt := NewProxyTest(t, WithCertManager(t.TempDir()), // No allowed domains: deny everything. WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: true, + InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, }), WithSessionID("test-session"), - WithSequenceCounter(seq), WithAuditor(aud), ).Start() defer pt.Stop() @@ -228,7 +221,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { require.Len(t, events, 1) require.False(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) - assert.Equal(t, uint64(0), *events[0].SequenceNumber) + assert.Equal(t, int32(0), events[0].SequenceNumber) // Backend never hit. assert.Equal(t, 0, llm.requestCount(), @@ -259,7 +252,6 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { require.NoError(t, err) aud := &capturingAuditor{} - seq := &audit.SequenceCounter{} pt := NewProxyTest(t, WithCertManager(t.TempDir()), @@ -272,11 +264,8 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { Domain: llmURL.Hostname(), Path: "/v1/*", }}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, }), WithSessionID(sessionID), - WithSequenceCounter(seq), WithAuditor(aud), ).Start() defer pt.Stop() @@ -309,11 +298,11 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { events := aud.getRequests() require.Len(t, events, 4, "expected exactly four audit events") - expectedSeq := []uint64{0, 1, 2, 3} + expectedSeq := []int32{0, 1, 2, 3} expectedAllowed := []bool{true, true, false, true} for i, ev := range events { require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must be set", i) - assert.Equal(t, expectedSeq[i], *ev.SequenceNumber, + assert.Equal(t, expectedSeq[i], ev.SequenceNumber, "event %d: wrong sequence number", i) assert.Equal(t, expectedAllowed[i], ev.Allowed, "event %d: wrong allowed flag", i) @@ -324,30 +313,30 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { "LLM backend should have received exactly two requests") firstLLMHdr := llm.headersAt(0) - assert.Equal(t, sessionID, firstLLMHdr.Get(config.DefaultSessionIDHeaderName)) - assert.Equal(t, "0", firstLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + assert.Equal(t, sessionID, firstLLMHdr.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", firstLLMHdr.Get(config.SequenceNumberHeaderName), "first LLM request must have sequence 0") secondLLMHdr := llm.headersAt(1) - assert.Equal(t, sessionID, secondLLMHdr.Get(config.DefaultSessionIDHeaderName)) - assert.Equal(t, "3", secondLLMHdr.Get(config.DefaultSequenceNumberHeaderName), + assert.Equal(t, sessionID, secondLLMHdr.Get(config.SessionIDHeaderName)) + assert.Equal(t, "3", secondLLMHdr.Get(config.SequenceNumberHeaderName), "second LLM request must have sequence 3") // -- Verify non-LLM backend has no correlation headers -- require.Equal(t, 1, other.requestCount()) otherHdr := other.headersAt(0) - assert.Empty(t, otherHdr.Get(config.DefaultSessionIDHeaderName)) - assert.Empty(t, otherHdr.Get(config.DefaultSequenceNumberHeaderName)) + assert.Empty(t, otherHdr.Get(config.SessionIDHeaderName)) + assert.Empty(t, otherHdr.Get(config.SequenceNumberHeaderName)) // -- Verify the gap reveals intermediate activity -- // The gap between the two LLM sequence numbers (0 and 3) means // that sequence numbers 1 and 2 were consumed by non-LLM // activity, matching audit events 1 (non-LLM allowed) and 2 // (denied). - firstLLMSeq := *events[0].SequenceNumber - secondLLMSeq := *events[3].SequenceNumber + firstLLMSeq := events[0].SequenceNumber + secondLLMSeq := events[3].SequenceNumber gap := secondLLMSeq - firstLLMSeq - 1 - assert.Equal(t, uint64(2), gap, + assert.Equal(t, int32(2), gap, "gap between LLM requests should reveal 2 intermediate events") } @@ -372,7 +361,6 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { require.NoError(t, err) aud := &capturingAuditor{} - seq := &audit.SequenceCounter{} pt := NewProxyTest(t, WithCertManager(t.TempDir()), @@ -384,11 +372,8 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { Domain: llmURL.Hostname(), Path: "/v1/*", }}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, }), WithSessionID(sessionID), - WithSequenceCounter(seq), WithAuditor(aud), ).Start() defer pt.Stop() @@ -412,24 +397,24 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // Verify LLM sequence headers. require.Equal(t, 2, llm.requestCount()) - assert.Equal(t, "0", llm.headersAt(0).Get(config.DefaultSequenceNumberHeaderName)) - assert.Equal(t, "4", llm.headersAt(1).Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, "0", llm.headersAt(0).Get(config.SequenceNumberHeaderName)) + assert.Equal(t, "4", llm.headersAt(1).Get(config.SequenceNumberHeaderName)) // The gap between sequence numbers 0 and 4 is 3, matching the // three tool-use requests in between. events := aud.getRequests() require.Len(t, events, 5) - firstLLMSeq := *events[0].SequenceNumber - secondLLMSeq := *events[4].SequenceNumber + firstLLMSeq := events[0].SequenceNumber + secondLLMSeq := events[4].SequenceNumber gap := secondLLMSeq - firstLLMSeq - 1 - assert.Equal(t, uint64(3), gap, + assert.Equal(t, int32(3), gap, "gap between prompts should equal number of tool-use requests") // Verify the intermediate events are the tool-use requests. for i := 1; i <= 3; i++ { require.NotNil(t, events[i].SequenceNumber) - assert.Equal(t, uint64(i), *events[i].SequenceNumber) + assert.Equal(t, int32(i), events[i].SequenceNumber) assert.True(t, events[i].Allowed) } } @@ -445,8 +430,8 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) req, err := http.NewRequest(http.MethodPost, s.llmBackend.server.URL+"/v1/messages", nil) require.NoError(t, err) - req.Header.Set(config.DefaultSessionIDHeaderName, "spoofed-session") - req.Header.Set(config.DefaultSequenceNumberHeaderName, "9999") + req.Header.Set(config.SessionIDHeaderName, "spoofed-session") + req.Header.Set(config.SequenceNumberHeaderName, "9999") resp, err := s.pt.proxyClient.Do(req) require.NoError(t, err) @@ -456,16 +441,16 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // Backend received real values, not spoofed. require.Equal(t, 1, s.llmBackend.requestCount()) hdr := s.llmBackend.headersAt(0) - assert.Equal(t, sessionID, hdr.Get(config.DefaultSessionIDHeaderName)) - assert.Equal(t, "0", hdr.Get(config.DefaultSequenceNumberHeaderName)) + assert.Equal(t, sessionID, hdr.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", hdr.Get(config.SequenceNumberHeaderName)) // Audit event agrees with header. events := s.auditor.getRequests() require.Len(t, events, 1) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, - strconv.FormatUint(*events[0].SequenceNumber, 10), - hdr.Get(config.DefaultSequenceNumberHeaderName), + strconv.Itoa(int(events[0].SequenceNumber)), + hdr.Get(config.SequenceNumberHeaderName), ) } @@ -487,10 +472,8 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi WithAllowedDomain(llmURL.Hostname()), // Correlation disabled; no sequence counter. WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, - SessionIDHeaderName: config.DefaultSessionIDHeaderName, - SequenceNumberHeaderName: config.DefaultSequenceNumberHeaderName, + Enabled: false, + InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, }), WithSessionID("should-not-appear"), // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. @@ -506,14 +489,14 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi // No correlation headers. require.Equal(t, 1, llm.requestCount()) hdr := llm.headersAt(0) - assert.Empty(t, hdr.Get(config.DefaultSessionIDHeaderName)) - assert.Empty(t, hdr.Get(config.DefaultSequenceNumberHeaderName)) + assert.Empty(t, hdr.Get(config.SessionIDHeaderName)) + assert.Empty(t, hdr.Get(config.SequenceNumberHeaderName)) // Audit event recorded but without a pre-allocated sequence // number (nil), because no SequenceCounter was provided. events := aud.getRequests() require.Len(t, events, 1) - assert.Nil(t, events[0].SequenceNumber, + assert.Equal(t, int32(0), events[0].SequenceNumber, "no sequence counter means no pre-allocated sequence number") } @@ -547,17 +530,17 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { require.Len(t, events, numRequests) // Collect all sequence numbers and verify uniqueness. - seen := make(map[uint64]bool, numRequests) + seen := make(map[int32]bool, numRequests) for i, ev := range events { require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must not be nil", i) - assert.False(t, seen[*ev.SequenceNumber], - "event %d: duplicate sequence number %d", i, *ev.SequenceNumber) - seen[*ev.SequenceNumber] = true + assert.False(t, seen[ev.SequenceNumber], + "event %d: duplicate sequence number %d", i, ev.SequenceNumber) + seen[ev.SequenceNumber] = true } // The set should be exactly {0, 1, ..., numRequests-1}. - for i := uint64(0); i < numRequests; i++ { + for i := int32(0); i < numRequests; i++ { assert.True(t, seen[i], "sequence number %d is missing from the set", i) } @@ -567,11 +550,11 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { headerSeqs := make(map[string]bool, numRequests) for i := 0; i < numRequests; i++ { hdr := s.llmBackend.headersAt(i) - seqStr := hdr.Get(config.DefaultSequenceNumberHeaderName) + seqStr := hdr.Get(config.SequenceNumberHeaderName) assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) headerSeqs[seqStr] = true } - for i := uint64(0); i < numRequests; i++ { + for i := int32(0); i < numRequests; i++ { assert.True(t, headerSeqs[fmt.Sprintf("%d", i)], "header sequence number %d is missing", i) } From f1726b5a25085d18cc30040e4792dec7b564d0b8 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 12:28:24 +0000 Subject: [PATCH 03/15] refactor(proxy): clean up integration test naming and style - Remove '// ---------- Integration Tests ----------' section separator - Rename 'hdr'/'Hdr' variables to 'header'/'Header' for clarity - Rename 'llm'/'llmBackend' to 'injectBackend'/'inject'/'backend' to reflect the actual concept (inject target) rather than a specific use case (LLM) - Update comments to match the new naming Generated by Coder Agents --- ...xy_session_correlation_integration_test.go | 226 +++++++++--------- 1 file changed, 112 insertions(+), 114 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 4000270..b2206fb 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -55,17 +55,17 @@ func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { // counter. Tests build one via newSessionCorrelationIntegrationSetup // and tear it down with stop. type sessionCorrelationIntegrationSetup struct { - pt *ProxyTest - auditor *capturingAuditor - seq *audit.SequenceCounter - llmBackend *multiRequestCapturingBackend - otherBackend *multiRequestCapturingBackend + pt *ProxyTest + auditor *capturingAuditor + seq *audit.SequenceCounter + injectBackend *multiRequestCapturingBackend + otherBackend *multiRequestCapturingBackend } func (s *sessionCorrelationIntegrationSetup) stop() { s.pt.Stop() - if s.llmBackend != nil { - s.llmBackend.close() + if s.injectBackend != nil { + s.injectBackend.close() } if s.otherBackend != nil { s.otherBackend.close() @@ -74,17 +74,16 @@ func (s *sessionCorrelationIntegrationSetup) stop() { // newSessionCorrelationIntegrationSetup builds a proxy that allows // traffic to two httptest backends: one that matches an inject target -// (simulating an LLM provider) and one that does not (simulating a -// generic allowed domain like github.com). Both backends capture all -// received request headers. A capturingAuditor records every audit -// event for later inspection. +// and one that does not (simulating a generic allowed domain like +// github.com). Both backends capture all received request headers. +// A capturingAuditor records every audit event for later inspection. func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sessionCorrelationIntegrationSetup { t.Helper() - llm := newMultiRequestCapturingBackend() + inject := newMultiRequestCapturingBackend() other := newMultiRequestCapturingBackend() - llmURL, err := url.Parse(llm.server.URL) + injectURL, err := url.Parse(inject.server.URL) require.NoError(t, err) otherURL, err := url.Parse(other.server.URL) @@ -94,18 +93,18 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess seq := &audit.SequenceCounter{} // Both httptest backends resolve to 127.0.0.1, so a domain-only - // inject target would match both. We use a path glob on the LLM - // paths (/v1/*) to limit header injection to LLM requests. + // inject target would match both. We use a path glob on the + // inject-target paths (/v1/*) to limit header injection. pt := NewProxyTest(t, WithCertManager(t.TempDir()), // Allow both backends. - WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), - // Only requests matching the LLM path receive headers. + // Only requests matching the inject-target path receive headers. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []config.InjectTarget{{ - Domain: llmURL.Hostname(), + Domain: injectURL.Hostname(), Path: "/v1/*", }}, }), @@ -114,16 +113,14 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess ).Start() return &sessionCorrelationIntegrationSetup{ - pt: pt, - auditor: aud, - seq: seq, - llmBackend: llm, - otherBackend: other, + pt: pt, + auditor: aud, + seq: seq, + injectBackend: inject, + otherBackend: other, } } -// ---------- Integration Tests ---------- - // TestIntegration_LLMRequestAuditAndHeadersAgree verifies the core // correlation invariant: when an allowed request hits an inject target, // the sequence number in the audit event equals the sequence number in @@ -133,7 +130,7 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { s := newSessionCorrelationIntegrationSetup(t, sessionID) defer s.stop() - resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) @@ -146,15 +143,15 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { assert.Equal(t, int32(0), events[0].SequenceNumber) // Forwarded headers. - require.Equal(t, 1, s.llmBackend.requestCount()) - hdr := s.llmBackend.headersAt(0) - assert.Equal(t, sessionID, hdr.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", hdr.Get(config.SequenceNumberHeaderName)) + require.Equal(t, 1, s.injectBackend.requestCount()) + header := s.injectBackend.headersAt(0) + assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) // The two must agree. assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), - hdr.Get(config.SequenceNumberHeaderName), + header.Get(config.SequenceNumberHeaderName), "audit event and forwarded header must carry the same sequence number", ) } @@ -181,10 +178,10 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { // No correlation headers on the backend. require.Equal(t, 1, s.otherBackend.requestCount()) - hdr := s.otherBackend.headersAt(0) - assert.Empty(t, hdr.Get(config.SessionIDHeaderName), + header := s.otherBackend.headersAt(0) + assert.Empty(t, header.Get(config.SessionIDHeaderName), "non-inject-target requests must not carry session ID header") - assert.Empty(t, hdr.Get(config.SequenceNumberHeaderName), + assert.Empty(t, header.Get(config.SequenceNumberHeaderName), "non-inject-target requests must not carry sequence number header") } @@ -194,8 +191,8 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { // Create a setup with a custom deny-all proxy, but keep the same // pattern of shared sequence counter and auditor. - llm := newMultiRequestCapturingBackend() - defer llm.close() + backend := newMultiRequestCapturingBackend() + defer backend.close() aud := &capturingAuditor{} @@ -211,7 +208,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { ).Start() defer pt.Stop() - resp, err := pt.proxyClient.Get(llm.server.URL + "/exfil") + resp, err := pt.proxyClient.Get(backend.server.URL + "/exfil") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) @@ -224,7 +221,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { assert.Equal(t, int32(0), events[0].SequenceNumber) // Backend never hit. - assert.Equal(t, 0, llm.requestCount(), + assert.Equal(t, 0, backend.requestCount(), "denied requests must not be forwarded to the backend") } @@ -238,14 +235,14 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { const sessionID = "mixed-test-session" - // Two allowed backends (LLM and "github"), one denied domain. - llm := newMultiRequestCapturingBackend() - defer llm.close() + // Two allowed backends (inject target and "github"), one denied domain. + inject := newMultiRequestCapturingBackend() + defer inject.close() other := newMultiRequestCapturingBackend() defer other.close() - llmURL, err := url.Parse(llm.server.URL) + injectURL, err := url.Parse(inject.server.URL) require.NoError(t, err) otherURL, err := url.Parse(other.server.URL) @@ -255,13 +252,13 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { pt := NewProxyTest(t, WithCertManager(t.TempDir()), - WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), - // Only LLM is an inject target. + // Only the inject backend is an inject target. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []config.InjectTarget{{ - Domain: llmURL.Hostname(), + Domain: injectURL.Hostname(), Path: "/v1/*", }}, }), @@ -270,13 +267,13 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { ).Start() defer pt.Stop() - // Request 0: LLM (allowed, inject target). - resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + // Request 0: inject target (allowed, headers injected). + resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Request 1: non-LLM (allowed, no inject). + // Request 1: non-inject-target (allowed, no headers). resp, err = pt.proxyClient.Get(other.server.URL + "/coder/coder") require.NoError(t, err) resp.Body.Close() //nolint:errcheck @@ -288,8 +285,8 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) - // Request 3: LLM again. - resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + // Request 3: inject target again. + resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) @@ -308,53 +305,54 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { "event %d: wrong allowed flag", i) } - // -- Verify LLM backend headers -- - require.Equal(t, 2, llm.requestCount(), - "LLM backend should have received exactly two requests") + // -- Verify inject-target backend headers -- + require.Equal(t, 2, inject.requestCount(), + "inject-target backend should have received exactly two requests") - firstLLMHdr := llm.headersAt(0) - assert.Equal(t, sessionID, firstLLMHdr.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", firstLLMHdr.Get(config.SequenceNumberHeaderName), - "first LLM request must have sequence 0") + firstInjectHeader := inject.headersAt(0) + assert.Equal(t, sessionID, firstInjectHeader.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", firstInjectHeader.Get(config.SequenceNumberHeaderName), + "first inject-target request must have sequence 0") - secondLLMHdr := llm.headersAt(1) - assert.Equal(t, sessionID, secondLLMHdr.Get(config.SessionIDHeaderName)) - assert.Equal(t, "3", secondLLMHdr.Get(config.SequenceNumberHeaderName), - "second LLM request must have sequence 3") + secondInjectHeader := inject.headersAt(1) + assert.Equal(t, sessionID, secondInjectHeader.Get(config.SessionIDHeaderName)) + assert.Equal(t, "3", secondInjectHeader.Get(config.SequenceNumberHeaderName), + "second inject-target request must have sequence 3") - // -- Verify non-LLM backend has no correlation headers -- + // -- Verify non-inject-target backend has no correlation headers -- require.Equal(t, 1, other.requestCount()) - otherHdr := other.headersAt(0) - assert.Empty(t, otherHdr.Get(config.SessionIDHeaderName)) - assert.Empty(t, otherHdr.Get(config.SequenceNumberHeaderName)) + otherHeader := other.headersAt(0) + assert.Empty(t, otherHeader.Get(config.SessionIDHeaderName)) + assert.Empty(t, otherHeader.Get(config.SequenceNumberHeaderName)) // -- Verify the gap reveals intermediate activity -- - // The gap between the two LLM sequence numbers (0 and 3) means - // that sequence numbers 1 and 2 were consumed by non-LLM - // activity, matching audit events 1 (non-LLM allowed) and 2 - // (denied). - firstLLMSeq := events[0].SequenceNumber - secondLLMSeq := events[3].SequenceNumber - gap := secondLLMSeq - firstLLMSeq - 1 + // The gap between the two inject-target sequence numbers (0 and 3) + // means that sequence numbers 1 and 2 were consumed by + // non-inject-target activity, matching audit events 1 + // (non-inject-target allowed) and 2 (denied). + firstInjectSeq := events[0].SequenceNumber + secondInjectSeq := events[3].SequenceNumber + gap := secondInjectSeq - firstInjectSeq - 1 assert.Equal(t, int32(2), gap, - "gap between LLM requests should reveal 2 intermediate events") + "gap between inject-target requests should reveal 2 intermediate events") } -// TestIntegration_SequenceGapRevealsAgenticLoop sends two LLM requests -// with several non-LLM requests in between, simulating an agentic loop -// where the model triggers tool-use HTTP calls between prompts. The -// test verifies that the gap in LLM sequence numbers precisely -// reflects the count of intermediate boundary events. +// TestIntegration_SequenceGapRevealsAgenticLoop sends two inject-target +// requests with several non-inject-target requests in between, +// simulating an agentic loop where the model triggers tool-use HTTP +// calls between prompts. The test verifies that the gap in +// inject-target sequence numbers precisely reflects the count of +// intermediate boundary events. func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { const sessionID = "agentic-loop-session" - llm := newMultiRequestCapturingBackend() - defer llm.close() + inject := newMultiRequestCapturingBackend() + defer inject.close() other := newMultiRequestCapturingBackend() defer other.close() - llmURL, err := url.Parse(llm.server.URL) + injectURL, err := url.Parse(inject.server.URL) require.NoError(t, err) otherURL, err := url.Parse(other.server.URL) @@ -364,12 +362,12 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { pt := NewProxyTest(t, WithCertManager(t.TempDir()), - WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []config.InjectTarget{{ - Domain: llmURL.Hostname(), + Domain: injectURL.Hostname(), Path: "/v1/*", }}, }), @@ -378,8 +376,8 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { ).Start() defer pt.Stop() - // First LLM prompt (seq 0). - resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + // First inject-target request (seq 0). + resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck @@ -390,24 +388,24 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { resp.Body.Close() //nolint:errcheck } - // Second LLM prompt (seq 4). - resp, err = pt.proxyClient.Get(llm.server.URL + "/v1/messages") + // Second inject-target request (seq 4). + resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck - // Verify LLM sequence headers. - require.Equal(t, 2, llm.requestCount()) - assert.Equal(t, "0", llm.headersAt(0).Get(config.SequenceNumberHeaderName)) - assert.Equal(t, "4", llm.headersAt(1).Get(config.SequenceNumberHeaderName)) + // Verify inject-target sequence headers. + require.Equal(t, 2, inject.requestCount()) + assert.Equal(t, "0", inject.headersAt(0).Get(config.SequenceNumberHeaderName)) + assert.Equal(t, "4", inject.headersAt(1).Get(config.SequenceNumberHeaderName)) // The gap between sequence numbers 0 and 4 is 3, matching the // three tool-use requests in between. events := aud.getRequests() require.Len(t, events, 5) - firstLLMSeq := events[0].SequenceNumber - secondLLMSeq := events[4].SequenceNumber - gap := secondLLMSeq - firstLLMSeq - 1 + firstInjectSeq := events[0].SequenceNumber + secondInjectSeq := events[4].SequenceNumber + gap := secondInjectSeq - firstInjectSeq - 1 assert.Equal(t, int32(3), gap, "gap between prompts should equal number of tool-use requests") @@ -428,7 +426,7 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) s := newSessionCorrelationIntegrationSetup(t, sessionID) defer s.stop() - req, err := http.NewRequest(http.MethodPost, s.llmBackend.server.URL+"/v1/messages", nil) + req, err := http.NewRequest(http.MethodPost, s.injectBackend.server.URL+"/v1/messages", nil) require.NoError(t, err) req.Header.Set(config.SessionIDHeaderName, "spoofed-session") req.Header.Set(config.SequenceNumberHeaderName, "9999") @@ -439,10 +437,10 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) require.Equal(t, http.StatusOK, resp.StatusCode) // Backend received real values, not spoofed. - require.Equal(t, 1, s.llmBackend.requestCount()) - hdr := s.llmBackend.headersAt(0) - assert.Equal(t, sessionID, hdr.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", hdr.Get(config.SequenceNumberHeaderName)) + require.Equal(t, 1, s.injectBackend.requestCount()) + header := s.injectBackend.headersAt(0) + assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) // Audit event agrees with header. events := s.auditor.getRequests() @@ -450,7 +448,7 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), - hdr.Get(config.SequenceNumberHeaderName), + header.Get(config.SequenceNumberHeaderName), ) } @@ -459,21 +457,21 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // not inject headers and does not pre-allocate sequence numbers (the // auditor falls back to its own counter instead). func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testing.T) { - llm := newMultiRequestCapturingBackend() - defer llm.close() + backend := newMultiRequestCapturingBackend() + defer backend.close() - llmURL, err := url.Parse(llm.server.URL) + backendURL, err := url.Parse(backend.server.URL) require.NoError(t, err) aud := &capturingAuditor{} pt := NewProxyTest(t, WithCertManager(t.TempDir()), - WithAllowedDomain(llmURL.Hostname()), + WithAllowedDomain(backendURL.Hostname()), // Correlation disabled; no sequence counter. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: llmURL.Hostname()}}, + InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("should-not-appear"), // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. @@ -481,16 +479,16 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi ).Start() defer pt.Stop() - resp, err := pt.proxyClient.Get(llm.server.URL + "/v1/messages") + resp, err := pt.proxyClient.Get(backend.server.URL + "/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) // No correlation headers. - require.Equal(t, 1, llm.requestCount()) - hdr := llm.headersAt(0) - assert.Empty(t, hdr.Get(config.SessionIDHeaderName)) - assert.Empty(t, hdr.Get(config.SequenceNumberHeaderName)) + require.Equal(t, 1, backend.requestCount()) + header := backend.headersAt(0) + assert.Empty(t, header.Get(config.SessionIDHeaderName)) + assert.Empty(t, header.Get(config.SequenceNumberHeaderName)) // Audit event recorded but without a pre-allocated sequence // number (nil), because no SequenceCounter was provided. @@ -516,7 +514,7 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - resp, err := s.pt.proxyClient.Get(s.llmBackend.server.URL + "/v1/messages") + resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") assert.NoError(t, err) if resp != nil { resp.Body.Close() //nolint:errcheck @@ -546,11 +544,11 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { } // Every header should also carry a matching sequence number. - require.Equal(t, numRequests, s.llmBackend.requestCount()) + require.Equal(t, numRequests, s.injectBackend.requestCount()) headerSeqs := make(map[string]bool, numRequests) for i := 0; i < numRequests; i++ { - hdr := s.llmBackend.headersAt(i) - seqStr := hdr.Get(config.SequenceNumberHeaderName) + header := s.injectBackend.headersAt(i) + seqStr := header.Get(config.SequenceNumberHeaderName) assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) headerSeqs[seqStr] = true } From 4db7db6d2c3a8e5a2d0a230bc7ccb603317c0319 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 12:34:02 +0000 Subject: [PATCH 04/15] refactor(proxy): improve comments and structure in session correlation integration tests - Enhanced comments for clarity regarding the purpose of `injectBackend` and `otherBackend`. - Removed unnecessary comments to streamline the test code. - Adjusted formatting for consistency and readability in the `sessionCorrelationIntegrationSetup` struct. These changes aim to improve the maintainability and understanding of the integration tests related to session correlation. --- .../proxy_session_correlation_integration_test.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index b2206fb..174e012 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -55,11 +55,16 @@ func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { // counter. Tests build one via newSessionCorrelationIntegrationSetup // and tear it down with stop. type sessionCorrelationIntegrationSetup struct { - pt *ProxyTest - auditor *capturingAuditor - seq *audit.SequenceCounter + pt *ProxyTest + auditor *capturingAuditor + seq *audit.SequenceCounter + // llmBackend expects headers to be injected as these requests are + // expected to be seen by the AI Gateway and then correlated back + // to the audit event injectBackend *multiRequestCapturingBackend - otherBackend *multiRequestCapturingBackend + // otherBackend does not expect headers to be injected as these + // requests should not be routed through the AI Gateway. + otherBackend *multiRequestCapturingBackend } func (s *sessionCorrelationIntegrationSetup) stop() { @@ -135,7 +140,6 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Audit event. events := s.auditor.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) @@ -148,7 +152,6 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) - // The two must agree. assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), header.Get(config.SequenceNumberHeaderName), From b84170e622c96526d8f0d0bd7c463ed01e049412 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 12:42:13 +0000 Subject: [PATCH 05/15] refactor(proxy): rename setup struct, fix disabled-correlation test - Rename sessionCorrelationIntegrationSetup to correlationTestEnv and newSessionCorrelationIntegrationSetup to newCorrelationTestEnv for brevity and clarity. - Rename TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence to TestIntegration_DisabledCorrelationNoHeaders. The sequence counter is a value type on the proxy server and always increments regardless of the correlation setting, so the previous name and assertions about 'no pre-allocated sequence number' were misleading. The test now focuses on what actually differs: no headers are injected. - Remove misleading 'auditor falls back to its own counter' comment. Generated by Coder Agents --- ...xy_session_correlation_integration_test.go | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 174e012..17b650c 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -50,24 +50,24 @@ func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { return m.all[i].Clone() } -// sessionCorrelationIntegrationSetup holds the shared objects for an +// correlationTestEnv holds the shared objects for a session-correlation // integration test: the proxy, auditor, backend(s), and sequence -// counter. Tests build one via newSessionCorrelationIntegrationSetup -// and tear it down with stop. -type sessionCorrelationIntegrationSetup struct { +// counter. Tests build one via newCorrelationTestEnv and tear it down +// with stop. +type correlationTestEnv struct { pt *ProxyTest auditor *capturingAuditor seq *audit.SequenceCounter - // llmBackend expects headers to be injected as these requests are - // expected to be seen by the AI Gateway and then correlated back - // to the audit event + // injectBackend expects headers to be injected as these requests + // are expected to be seen by the AI Gateway and then correlated + // back to the audit event. injectBackend *multiRequestCapturingBackend // otherBackend does not expect headers to be injected as these // requests should not be routed through the AI Gateway. otherBackend *multiRequestCapturingBackend } -func (s *sessionCorrelationIntegrationSetup) stop() { +func (s *correlationTestEnv) stop() { s.pt.Stop() if s.injectBackend != nil { s.injectBackend.close() @@ -77,12 +77,12 @@ func (s *sessionCorrelationIntegrationSetup) stop() { } } -// newSessionCorrelationIntegrationSetup builds a proxy that allows -// traffic to two httptest backends: one that matches an inject target -// and one that does not (simulating a generic allowed domain like -// github.com). Both backends capture all received request headers. -// A capturingAuditor records every audit event for later inspection. -func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sessionCorrelationIntegrationSetup { +// newCorrelationTestEnv builds a proxy that allows traffic to two +// httptest backends: one that matches an inject target and one that +// does not (simulating a generic allowed domain like github.com). +// Both backends capture all received request headers. A +// capturingAuditor records every audit event for later inspection. +func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { t.Helper() inject := newMultiRequestCapturingBackend() @@ -117,7 +117,7 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess WithAuditor(aud), ).Start() - return &sessionCorrelationIntegrationSetup{ + return &correlationTestEnv{ pt: pt, auditor: aud, seq: seq, @@ -132,7 +132,7 @@ func newSessionCorrelationIntegrationSetup(t *testing.T, sessionID string) *sess // the forwarded header. func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { const sessionID = "e5f6a7b8-0000-0000-0000-000000000000" - s := newSessionCorrelationIntegrationSetup(t, sessionID) + s := newCorrelationTestEnv(t, sessionID) defer s.stop() resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") @@ -164,7 +164,7 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { // audited (with a sequence number) but does NOT receive correlation // headers. func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { - s := newSessionCorrelationIntegrationSetup(t, "test-session") + s := newCorrelationTestEnv(t, "test-session") defer s.stop() resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls") @@ -426,7 +426,7 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // sequence number, and the audit event still agrees with the header. func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) { const sessionID = "real-session-uuid" - s := newSessionCorrelationIntegrationSetup(t, sessionID) + s := newCorrelationTestEnv(t, sessionID) defer s.stop() req, err := http.NewRequest(http.MethodPost, s.injectBackend.server.URL+"/v1/messages", nil) @@ -455,11 +455,13 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) ) } -// TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence -// verifies that when session correlation is disabled, the proxy does -// not inject headers and does not pre-allocate sequence numbers (the -// auditor falls back to its own counter instead). -func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testing.T) { +// TestIntegration_DisabledCorrelationNoHeaders verifies that when +// session correlation is disabled, the proxy does not inject +// correlation headers even for requests that match an inject target. +// Note: the sequence counter is a value type on the proxy server and +// always increments regardless of the correlation setting, so we only +// assert on the absence of headers here. +func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { backend := newMultiRequestCapturingBackend() defer backend.close() @@ -471,13 +473,11 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi pt := NewProxyTest(t, WithCertManager(t.TempDir()), WithAllowedDomain(backendURL.Hostname()), - // Correlation disabled; no sequence counter. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: false, InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, }), WithSessionID("should-not-appear"), - // Explicitly do NOT set WithSequenceCounter; seqCounter is nil. WithAuditor(aud), ).Start() defer pt.Stop() @@ -487,18 +487,18 @@ func TestIntegration_DisabledCorrelationNoHeadersNoPreallocatedSequence(t *testi defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // No correlation headers. + // No correlation headers injected. require.Equal(t, 1, backend.requestCount()) header := backend.headersAt(0) - assert.Empty(t, header.Get(config.SessionIDHeaderName)) - assert.Empty(t, header.Get(config.SequenceNumberHeaderName)) + assert.Empty(t, header.Get(config.SessionIDHeaderName), + "session ID header must not be injected when correlation is disabled") + assert.Empty(t, header.Get(config.SequenceNumberHeaderName), + "sequence number header must not be injected when correlation is disabled") - // Audit event recorded but without a pre-allocated sequence - // number (nil), because no SequenceCounter was provided. + // Request is still audited. events := aud.getRequests() require.Len(t, events, 1) - assert.Equal(t, int32(0), events[0].SequenceNumber, - "no sequence counter means no pre-allocated sequence number") + require.True(t, events[0].Allowed) } // TestIntegration_ConcurrentRequestsUniqueSequenceNumbers sends @@ -509,7 +509,7 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { const sessionID = "concurrent-session" const numRequests = 10 - s := newSessionCorrelationIntegrationSetup(t, sessionID) + s := newCorrelationTestEnv(t, sessionID) defer s.stop() var wg sync.WaitGroup From 49321bddcea1557c9105edb7885f0f90cd122dd8 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 13:28:15 +0000 Subject: [PATCH 06/15] fix(proxy): sync with main branch InjectTarget changes PR #201 was merged to main, replacing config.InjectTarget struct with []string rule specs (rulesengine syntax). Update the integration test file to use the []string format, and sync config/, proxy/proxy.go, and other test files from main to fix the build. Generated by Coder Agents --- ...xy_session_correlation_integration_test.go | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 17b650c..5baef04 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -108,10 +108,7 @@ func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { // Only requests matching the inject-target path receive headers. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{ - Domain: injectURL.Hostname(), - Path: "/v1/*", - }}, + InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), WithAuditor(aud), @@ -204,7 +201,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { // No allowed domains: deny everything. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{Domain: "anything.example.com"}}, + InjectTargets: []string{"domain=anything.example.com"}, }), WithSessionID("test-session"), WithAuditor(aud), @@ -260,10 +257,7 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { // Only the inject backend is an inject target. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{ - Domain: injectURL.Hostname(), - Path: "/v1/*", - }}, + InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), WithAuditor(aud), @@ -369,10 +363,7 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { WithAllowedDomain(otherURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, - InjectTargets: []config.InjectTarget{{ - Domain: injectURL.Hostname(), - Path: "/v1/*", - }}, + InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), WithAuditor(aud), @@ -475,7 +466,7 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { WithAllowedDomain(backendURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: false, - InjectTargets: []config.InjectTarget{{Domain: backendURL.Hostname()}}, + InjectTargets: []string{"domain=" + backendURL.Hostname()}, }), WithSessionID("should-not-appear"), WithAuditor(aud), From f578460773fd4f6ed1453e43822fb0372ed10b50 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Wed, 13 May 2026 13:31:34 +0000 Subject: [PATCH 07/15] make fmt --- proxy/proxy_session_correlation_integration_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 5baef04..3f36bb1 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -107,7 +107,7 @@ func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { WithAllowedDomain(otherURL.Hostname()), // Only requests matching the inject-target path receive headers. WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, + Enabled: true, InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), @@ -256,7 +256,7 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { WithAllowedDomain(otherURL.Hostname()), // Only the inject backend is an inject target. WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, + Enabled: true, InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), @@ -362,7 +362,7 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, + Enabled: true, InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, }), WithSessionID(sessionID), From a9d59ec67b877527bad759a50722e82a6af99d11 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 14 May 2026 17:24:27 +0000 Subject: [PATCH 08/15] test(proxy): use Given/When/Then style comments in integration tests --- ...xy_session_correlation_integration_test.go | 81 ++++++++++--------- 1 file changed, 44 insertions(+), 37 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 3f36bb1..c21908c 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -128,27 +128,31 @@ func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { // the sequence number in the audit event equals the sequence number in // the forwarded header. func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { + // Given: a proxy with session correlation enabled and an inject-target backend. const sessionID = "e5f6a7b8-0000-0000-0000-000000000000" s := newCorrelationTestEnv(t, sessionID) defer s.stop() + // When: a single request is sent to the inject-target backend. resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) + // Then: the audit event records the correct sequence number. events := s.auditor.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) - // Forwarded headers. + // Then: the forwarded request carries the session ID and sequence number headers. require.Equal(t, 1, s.injectBackend.requestCount()) header := s.injectBackend.headersAt(0) assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) + // Then: the audit event and forwarded header agree on the sequence number. assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), header.Get(config.SequenceNumberHeaderName), @@ -161,22 +165,24 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { // audited (with a sequence number) but does NOT receive correlation // headers. func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { + // Given: a proxy with session correlation enabled and a non-inject-target backend. s := newCorrelationTestEnv(t, "test-session") defer s.stop() + // When: a request is sent to the non-inject-target backend. resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Audit event recorded. + // Then: an audit event is recorded with a sequence number. events := s.auditor.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) - // No correlation headers on the backend. + // Then: no correlation headers are present on the forwarded request. require.Equal(t, 1, s.otherBackend.requestCount()) header := s.otherBackend.headersAt(0) assert.Empty(t, header.Get(config.SessionIDHeaderName), @@ -189,8 +195,7 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { // request denied by the rules engine is audited (consuming a sequence // number) but is never forwarded to any backend. func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { - // Create a setup with a custom deny-all proxy, but keep the same - // pattern of shared sequence counter and auditor. + // Given: a proxy with no allowed domains (deny-all configuration). backend := newMultiRequestCapturingBackend() defer backend.close() @@ -198,7 +203,6 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { pt := NewProxyTest(t, WithCertManager(t.TempDir()), - // No allowed domains: deny everything. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []string{"domain=anything.example.com"}, @@ -208,19 +212,20 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { ).Start() defer pt.Stop() + // When: a request is sent to a domain that is not allowed. resp, err := pt.proxyClient.Get(backend.server.URL + "/exfil") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) - // Audit event recorded. + // Then: an audit event is recorded with the denied flag and a sequence number. events := aud.getRequests() require.Len(t, events, 1) require.False(t, events[0].Allowed) require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) - // Backend never hit. + // Then: the backend never receives the request. assert.Equal(t, 0, backend.requestCount(), "denied requests must not be forwarded to the backend") } @@ -235,7 +240,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { const sessionID = "mixed-test-session" - // Two allowed backends (inject target and "github"), one denied domain. + // Given: a proxy with an inject-target backend and a non-inject-target backend. inject := newMultiRequestCapturingBackend() defer inject.close() @@ -254,7 +259,6 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { WithCertManager(t.TempDir()), WithAllowedDomain(injectURL.Hostname()), WithAllowedDomain(otherURL.Hostname()), - // Only the inject backend is an inject target. WithSessionCorrelation(config.SessionCorrelationConfig{ Enabled: true, InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, @@ -264,31 +268,30 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { ).Start() defer pt.Stop() - // Request 0: inject target (allowed, headers injected). + // When: an inject-target, non-inject-target, denied, and inject-target + // request are sent in sequence. resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Request 1: non-inject-target (allowed, no headers). resp, err = pt.proxyClient.Get(other.server.URL + "/coder/coder") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Request 2: denied (nothing is allowed for evil.example.com). resp, err = pt.proxyClient.Get("http://evil.example.com/exfil") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) - // Request 3: inject target again. resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // -- Verify audit events -- + // Then: all four requests produce audit events with monotonically + // increasing sequence numbers. events := aud.getRequests() require.Len(t, events, 4, "expected exactly four audit events") @@ -302,7 +305,8 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { "event %d: wrong allowed flag", i) } - // -- Verify inject-target backend headers -- + // Then: the inject-target backend receives correlation headers with + // the correct sequence numbers. require.Equal(t, 2, inject.requestCount(), "inject-target backend should have received exactly two requests") @@ -316,17 +320,14 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { assert.Equal(t, "3", secondInjectHeader.Get(config.SequenceNumberHeaderName), "second inject-target request must have sequence 3") - // -- Verify non-inject-target backend has no correlation headers -- + // Then: the non-inject-target backend receives no correlation headers. require.Equal(t, 1, other.requestCount()) otherHeader := other.headersAt(0) assert.Empty(t, otherHeader.Get(config.SessionIDHeaderName)) assert.Empty(t, otherHeader.Get(config.SequenceNumberHeaderName)) - // -- Verify the gap reveals intermediate activity -- - // The gap between the two inject-target sequence numbers (0 and 3) - // means that sequence numbers 1 and 2 were consumed by - // non-inject-target activity, matching audit events 1 - // (non-inject-target allowed) and 2 (denied). + // Then: the gap between inject-target sequence numbers (0 and 3) + // reveals 2 intermediate events (non-inject-target allowed and denied). firstInjectSeq := events[0].SequenceNumber secondInjectSeq := events[3].SequenceNumber gap := secondInjectSeq - firstInjectSeq - 1 @@ -343,6 +344,7 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { const sessionID = "agentic-loop-session" + // Given: a proxy with an inject-target and a non-inject-target backend. inject := newMultiRequestCapturingBackend() defer inject.close() @@ -370,30 +372,29 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { ).Start() defer pt.Stop() - // First inject-target request (seq 0). + // When: an inject-target request, three tool-use requests to the + // non-inject-target backend, and another inject-target request are + // sent in sequence. resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck - // Agentic loop: three tool-use HTTP calls. for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { resp, err = pt.proxyClient.Get(other.server.URL + p) require.NoError(t, err) resp.Body.Close() //nolint:errcheck } - // Second inject-target request (seq 4). resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck - // Verify inject-target sequence headers. + // Then: the inject-target headers show a gap from sequence 0 to 4. require.Equal(t, 2, inject.requestCount()) assert.Equal(t, "0", inject.headersAt(0).Get(config.SequenceNumberHeaderName)) assert.Equal(t, "4", inject.headersAt(1).Get(config.SequenceNumberHeaderName)) - // The gap between sequence numbers 0 and 4 is 3, matching the - // three tool-use requests in between. + // Then: the gap of 3 matches the three intermediate tool-use requests. events := aud.getRequests() require.Len(t, events, 5) @@ -403,7 +404,7 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { assert.Equal(t, int32(3), gap, "gap between prompts should equal number of tool-use requests") - // Verify the intermediate events are the tool-use requests. + // Then: the intermediate audit events correspond to the tool-use requests. for i := 1; i <= 3; i++ { require.NotNil(t, events[i].SequenceNumber) assert.Equal(t, int32(i), events[i].SequenceNumber) @@ -416,10 +417,12 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // the proxy replaces them with the real session ID and the real // sequence number, and the audit event still agrees with the header. func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) { + // Given: a proxy with session correlation enabled. const sessionID = "real-session-uuid" s := newCorrelationTestEnv(t, sessionID) defer s.stop() + // When: a request is sent with spoofed correlation headers. req, err := http.NewRequest(http.MethodPost, s.injectBackend.server.URL+"/v1/messages", nil) require.NoError(t, err) req.Header.Set(config.SessionIDHeaderName, "spoofed-session") @@ -430,13 +433,13 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // Backend received real values, not spoofed. + // Then: the backend receives the real values, not the spoofed ones. require.Equal(t, 1, s.injectBackend.requestCount()) header := s.injectBackend.headersAt(0) assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) - // Audit event agrees with header. + // Then: the audit event agrees with the forwarded header. events := s.auditor.getRequests() require.Len(t, events, 1) require.NotNil(t, events[0].SequenceNumber) @@ -453,6 +456,7 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // always increments regardless of the correlation setting, so we only // assert on the absence of headers here. func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { + // Given: a proxy with session correlation disabled. backend := newMultiRequestCapturingBackend() defer backend.close() @@ -473,12 +477,13 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { ).Start() defer pt.Stop() + // When: a request is sent that would match an inject target. resp, err := pt.proxyClient.Get(backend.server.URL + "/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - // No correlation headers injected. + // Then: no correlation headers are injected on the forwarded request. require.Equal(t, 1, backend.requestCount()) header := backend.headersAt(0) assert.Empty(t, header.Get(config.SessionIDHeaderName), @@ -486,7 +491,7 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { assert.Empty(t, header.Get(config.SequenceNumberHeaderName), "sequence number header must not be injected when correlation is disabled") - // Request is still audited. + // Then: the request is still audited. events := aud.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) @@ -500,9 +505,11 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { const sessionID = "concurrent-session" const numRequests = 10 + // Given: a proxy with session correlation enabled. s := newCorrelationTestEnv(t, sessionID) defer s.stop() + // When: multiple requests are sent concurrently to the inject-target backend. var wg sync.WaitGroup for i := 0; i < numRequests; i++ { wg.Add(1) @@ -517,11 +524,11 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { } wg.Wait() - // Every request should have been audited. + // Then: every request is audited. events := s.auditor.getRequests() require.Len(t, events, numRequests) - // Collect all sequence numbers and verify uniqueness. + // Then: each audit event has a unique sequence number. seen := make(map[int32]bool, numRequests) for i, ev := range events { require.NotNil(t, ev.SequenceNumber, @@ -531,13 +538,13 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { seen[ev.SequenceNumber] = true } - // The set should be exactly {0, 1, ..., numRequests-1}. + // Then: the sequence numbers form a dense set {0, 1, ..., numRequests-1}. for i := int32(0); i < numRequests; i++ { assert.True(t, seen[i], "sequence number %d is missing from the set", i) } - // Every header should also carry a matching sequence number. + // Then: every forwarded request header carries a matching sequence number. require.Equal(t, numRequests, s.injectBackend.requestCount()) headerSeqs := make(map[string]bool, numRequests) for i := 0; i < numRequests; i++ { From ecc0a5882e97ad326d6a9827ae25eea9ab1fc3d8 Mon Sep 17 00:00:00 2001 From: "doc-check[bot]" Date: Mon, 18 May 2026 09:00:38 +0000 Subject: [PATCH 09/15] fix(proxy): address review feedback on session correlation integration tests - Remove redundant require.NotNil assertions on int32 SequenceNumber fields (value types cannot be nil) - Assert audit events are empty before the request in LLMRequestAuditAndHeadersAgree to prove causality - Add missing HTTP status code checks in SequenceGapRevealsAgenticLoop - Use context-aware HTTP requests throughout via doGet helper and http.NewRequestWithContext with t.Context() Co-authored-by: Coder Agents --- ...xy_session_correlation_integration_test.go | 52 +++++++++++-------- 1 file changed, 31 insertions(+), 21 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index c21908c..4688718 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -15,6 +15,18 @@ import ( "github.com/stretchr/testify/require" ) +// doGet is a test helper that creates a context-aware GET request and +// executes it. The context is derived from the test so that in-flight +// requests are cancelled when the test finishes. +func doGet(t *testing.T, client *http.Client, url string) (*http.Response, error) { + t.Helper() + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, url, nil) + if err != nil { + return nil, err + } + return client.Do(req) +} + // multiRequestCapturingBackend records the headers from every request it // receives, not just the last one. This is needed by integration tests // that send multiple requests to the same backend and want to verify @@ -133,8 +145,11 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { s := newCorrelationTestEnv(t, sessionID) defer s.stop() + // Precondition: no audit events exist before the request. + require.Empty(t, s.auditor.getRequests(), "no audit events should exist before the request") + // When: a single request is sent to the inject-target backend. - resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") + resp, err := doGet(t, s.pt.proxyClient, s.injectBackend.server.URL+"/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) @@ -143,7 +158,6 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { events := s.auditor.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) - require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) // Then: the forwarded request carries the session ID and sequence number headers. @@ -170,7 +184,7 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { defer s.stop() // When: a request is sent to the non-inject-target backend. - resp, err := s.pt.proxyClient.Get(s.otherBackend.server.URL + "/pulls") + resp, err := doGet(t, s.pt.proxyClient, s.otherBackend.server.URL+"/pulls") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) @@ -179,7 +193,6 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { events := s.auditor.getRequests() require.Len(t, events, 1) require.True(t, events[0].Allowed) - require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) // Then: no correlation headers are present on the forwarded request. @@ -213,7 +226,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { defer pt.Stop() // When: a request is sent to a domain that is not allowed. - resp, err := pt.proxyClient.Get(backend.server.URL + "/exfil") + resp, err := doGet(t, pt.proxyClient, backend.server.URL+"/exfil") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) @@ -222,7 +235,6 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { events := aud.getRequests() require.Len(t, events, 1) require.False(t, events[0].Allowed) - require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, int32(0), events[0].SequenceNumber) // Then: the backend never receives the request. @@ -270,22 +282,22 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { // When: an inject-target, non-inject-target, denied, and inject-target // request are sent in sequence. - resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") + resp, err := doGet(t, pt.proxyClient, inject.server.URL+"/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - resp, err = pt.proxyClient.Get(other.server.URL + "/coder/coder") + resp, err = doGet(t, pt.proxyClient, other.server.URL+"/coder/coder") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) - resp, err = pt.proxyClient.Get("http://evil.example.com/exfil") + resp, err = doGet(t, pt.proxyClient, "http://evil.example.com/exfil") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusForbidden, resp.StatusCode) - resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") + resp, err = doGet(t, pt.proxyClient, inject.server.URL+"/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) @@ -298,7 +310,6 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { expectedSeq := []int32{0, 1, 2, 3} expectedAllowed := []bool{true, true, false, true} for i, ev := range events { - require.NotNil(t, ev.SequenceNumber, "event %d: sequence number must be set", i) assert.Equal(t, expectedSeq[i], ev.SequenceNumber, "event %d: wrong sequence number", i) assert.Equal(t, expectedAllowed[i], ev.Allowed, @@ -375,19 +386,22 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // When: an inject-target request, three tool-use requests to the // non-inject-target backend, and another inject-target request are // sent in sequence. - resp, err := pt.proxyClient.Get(inject.server.URL + "/v1/messages") + resp, err := doGet(t, pt.proxyClient, inject.server.URL+"/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { - resp, err = pt.proxyClient.Get(other.server.URL + p) + resp, err = doGet(t, pt.proxyClient, other.server.URL+p) require.NoError(t, err) resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) } - resp, err = pt.proxyClient.Get(inject.server.URL + "/v1/messages") + resp, err = doGet(t, pt.proxyClient, inject.server.URL+"/v1/messages") require.NoError(t, err) resp.Body.Close() //nolint:errcheck + require.Equal(t, http.StatusOK, resp.StatusCode) // Then: the inject-target headers show a gap from sequence 0 to 4. require.Equal(t, 2, inject.requestCount()) @@ -406,7 +420,6 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // Then: the intermediate audit events correspond to the tool-use requests. for i := 1; i <= 3; i++ { - require.NotNil(t, events[i].SequenceNumber) assert.Equal(t, int32(i), events[i].SequenceNumber) assert.True(t, events[i].Allowed) } @@ -423,7 +436,7 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) defer s.stop() // When: a request is sent with spoofed correlation headers. - req, err := http.NewRequest(http.MethodPost, s.injectBackend.server.URL+"/v1/messages", nil) + req, err := http.NewRequestWithContext(t.Context(), http.MethodPost, s.injectBackend.server.URL+"/v1/messages", nil) require.NoError(t, err) req.Header.Set(config.SessionIDHeaderName, "spoofed-session") req.Header.Set(config.SequenceNumberHeaderName, "9999") @@ -442,7 +455,6 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // Then: the audit event agrees with the forwarded header. events := s.auditor.getRequests() require.Len(t, events, 1) - require.NotNil(t, events[0].SequenceNumber) assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), header.Get(config.SequenceNumberHeaderName), @@ -478,7 +490,7 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { defer pt.Stop() // When: a request is sent that would match an inject target. - resp, err := pt.proxyClient.Get(backend.server.URL + "/v1/messages") + resp, err := doGet(t, pt.proxyClient, backend.server.URL+"/v1/messages") require.NoError(t, err) defer resp.Body.Close() //nolint:errcheck require.Equal(t, http.StatusOK, resp.StatusCode) @@ -515,7 +527,7 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - resp, err := s.pt.proxyClient.Get(s.injectBackend.server.URL + "/v1/messages") + resp, err := doGet(t, s.pt.proxyClient, s.injectBackend.server.URL+"/v1/messages") assert.NoError(t, err) if resp != nil { resp.Body.Close() //nolint:errcheck @@ -531,8 +543,6 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { // Then: each audit event has a unique sequence number. seen := make(map[int32]bool, numRequests) for i, ev := range events { - require.NotNil(t, ev.SequenceNumber, - "event %d: sequence number must not be nil", i) assert.False(t, seen[ev.SequenceNumber], "event %d: duplicate sequence number %d", i, ev.SequenceNumber) seen[ev.SequenceNumber] = true From 7339ad53cbd836d0c09fa7d96478c7500e8b0ac6 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Mon, 18 May 2026 09:42:11 +0000 Subject: [PATCH 10/15] feat(proxy): add ExpectGetViaProxy method for context-aware GET requests - Introduced ExpectGetViaProxy method to streamline context-bound GET requests through the proxy, enhancing test reliability. - Updated integration tests to utilize ExpectGetViaProxy, improving readability and consistency in handling HTTP requests. - Removed redundant doGet helper function as its functionality is now encapsulated in ExpectGetViaProxy. These changes aim to improve the maintainability and clarity of the proxy integration tests. --- proxy/proxy_framework_test.go | 14 ++++ ...xy_session_correlation_integration_test.go | 74 ++++--------------- 2 files changed, 28 insertions(+), 60 deletions(-) diff --git a/proxy/proxy_framework_test.go b/proxy/proxy_framework_test.go index 87edf26..bfc69e1 100644 --- a/proxy/proxy_framework_test.go +++ b/proxy/proxy_framework_test.go @@ -317,6 +317,20 @@ func (pt *ProxyTest) ExpectDenyViaProxy(targetURL string) { require.Contains(pt.t, string(body), "Request Blocked by Boundary", "Expected request to be blocked") } +// ExpectGetViaProxy makes a context-bound GET request through the proxy +// and fails the test immediately if the transport errors or the +// response status does not match wantStatus. The response body is +// drained and closed before returning. +func (pt *ProxyTest) ExpectGetViaProxy(targetURL string, wantStatus int) { + pt.t.Helper() + req, err := http.NewRequestWithContext(pt.t.Context(), http.MethodGet, targetURL, nil) + require.NoError(pt.t, err) + resp, err := pt.proxyClient.Do(req) + require.NoError(pt.t, err) + defer resp.Body.Close() //nolint:errcheck + require.Equal(pt.t, wantStatus, resp.StatusCode) +} + // ExpectAllowedViaProxy makes a request through the proxy using proxy transport (implicit CONNECT for HTTPS) // and expects it to be allowed with the given response body func (pt *ProxyTest) ExpectAllowedViaProxy(targetURL, expectedBody string) { diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 4688718..db153ef 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -15,18 +15,6 @@ import ( "github.com/stretchr/testify/require" ) -// doGet is a test helper that creates a context-aware GET request and -// executes it. The context is derived from the test so that in-flight -// requests are cancelled when the test finishes. -func doGet(t *testing.T, client *http.Client, url string) (*http.Response, error) { - t.Helper() - req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, url, nil) - if err != nil { - return nil, err - } - return client.Do(req) -} - // multiRequestCapturingBackend records the headers from every request it // receives, not just the last one. This is needed by integration tests // that send multiple requests to the same backend and want to verify @@ -149,10 +137,7 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { require.Empty(t, s.auditor.getRequests(), "no audit events should exist before the request") // When: a single request is sent to the inject-target backend. - resp, err := doGet(t, s.pt.proxyClient, s.injectBackend.server.URL+"/v1/messages") - require.NoError(t, err) - defer resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) + s.pt.ExpectGetViaProxy(s.injectBackend.server.URL+"/v1/messages", http.StatusOK) // Then: the audit event records the correct sequence number. events := s.auditor.getRequests() @@ -184,10 +169,7 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { defer s.stop() // When: a request is sent to the non-inject-target backend. - resp, err := doGet(t, s.pt.proxyClient, s.otherBackend.server.URL+"/pulls") - require.NoError(t, err) - defer resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) + s.pt.ExpectGetViaProxy(s.otherBackend.server.URL+"/pulls", http.StatusOK) // Then: an audit event is recorded with a sequence number. events := s.auditor.getRequests() @@ -226,10 +208,7 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { defer pt.Stop() // When: a request is sent to a domain that is not allowed. - resp, err := doGet(t, pt.proxyClient, backend.server.URL+"/exfil") - require.NoError(t, err) - defer resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusForbidden, resp.StatusCode) + pt.ExpectGetViaProxy(backend.server.URL+"/exfil", http.StatusForbidden) // Then: an audit event is recorded with the denied flag and a sequence number. events := aud.getRequests() @@ -282,25 +261,10 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { // When: an inject-target, non-inject-target, denied, and inject-target // request are sent in sequence. - resp, err := doGet(t, pt.proxyClient, inject.server.URL+"/v1/messages") - require.NoError(t, err) - resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) - - resp, err = doGet(t, pt.proxyClient, other.server.URL+"/coder/coder") - require.NoError(t, err) - resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) - - resp, err = doGet(t, pt.proxyClient, "http://evil.example.com/exfil") - require.NoError(t, err) - resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusForbidden, resp.StatusCode) - - resp, err = doGet(t, pt.proxyClient, inject.server.URL+"/v1/messages") - require.NoError(t, err) - resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) + pt.ExpectGetViaProxy(inject.server.URL+"/v1/messages", http.StatusOK) + pt.ExpectGetViaProxy(other.server.URL+"/coder/coder", http.StatusOK) + pt.ExpectGetViaProxy("http://evil.example.com/exfil", http.StatusForbidden) + pt.ExpectGetViaProxy(inject.server.URL+"/v1/messages", http.StatusOK) // Then: all four requests produce audit events with monotonically // increasing sequence numbers. @@ -386,22 +350,13 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // When: an inject-target request, three tool-use requests to the // non-inject-target backend, and another inject-target request are // sent in sequence. - resp, err := doGet(t, pt.proxyClient, inject.server.URL+"/v1/messages") - require.NoError(t, err) - resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) + pt.ExpectGetViaProxy(inject.server.URL+"/v1/messages", http.StatusOK) for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { - resp, err = doGet(t, pt.proxyClient, other.server.URL+p) - require.NoError(t, err) - resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) + pt.ExpectGetViaProxy(other.server.URL+p, http.StatusOK) } - resp, err = doGet(t, pt.proxyClient, inject.server.URL+"/v1/messages") - require.NoError(t, err) - resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) + pt.ExpectGetViaProxy(inject.server.URL+"/v1/messages", http.StatusOK) // Then: the inject-target headers show a gap from sequence 0 to 4. require.Equal(t, 2, inject.requestCount()) @@ -490,10 +445,7 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { defer pt.Stop() // When: a request is sent that would match an inject target. - resp, err := doGet(t, pt.proxyClient, backend.server.URL+"/v1/messages") - require.NoError(t, err) - defer resp.Body.Close() //nolint:errcheck - require.Equal(t, http.StatusOK, resp.StatusCode) + pt.ExpectGetViaProxy(backend.server.URL+"/v1/messages", http.StatusOK) // Then: no correlation headers are injected on the forwarded request. require.Equal(t, 1, backend.requestCount()) @@ -527,7 +479,9 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - resp, err := doGet(t, s.pt.proxyClient, s.injectBackend.server.URL+"/v1/messages") + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, s.injectBackend.server.URL+"/v1/messages", nil) + assert.NoError(t, err) + resp, err := s.pt.proxyClient.Do(req) assert.NoError(t, err) if resp != nil { resp.Body.Close() //nolint:errcheck From 9057c901d70c9cb8d75db83fdb43974685e219d0 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 21 May 2026 10:00:59 +0000 Subject: [PATCH 11/15] fix(proxy): enhance error handling in headersAt method of multiRequestCapturingBackend - Updated headersAt method to return an error when the index is out of range, improving robustness. - Adjusted integration tests to handle the new error return, ensuring proper assertions and error checks. - This change enhances the reliability of session correlation tests by preventing potential panics from invalid index access. --- ...xy_session_correlation_integration_test.go | 42 +++++++++++++------ 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index db153ef..70d2d19 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -1,6 +1,7 @@ package proxy import ( + "errors" "fmt" "net/http" "net/http/httptest" @@ -36,6 +37,8 @@ func newMultiRequestCapturingBackend() *multiRequestCapturingBackend { return mcb } +var errHeaderIndexOutOfRange = errors.New("headersAt: index out of range") + func (m *multiRequestCapturingBackend) close() { m.server.Close() } func (m *multiRequestCapturingBackend) requestCount() int { @@ -44,10 +47,13 @@ func (m *multiRequestCapturingBackend) requestCount() int { return len(m.all) } -func (m *multiRequestCapturingBackend) headersAt(i int) http.Header { +func (m *multiRequestCapturingBackend) headersAt(i int) (http.Header, error) { m.mu.Lock() defer m.mu.Unlock() - return m.all[i].Clone() + if i < 0 || i >= len(m.all) { + return nil, errHeaderIndexOutOfRange + } + return m.all[i].Clone(), nil } // correlationTestEnv holds the shared objects for a session-correlation @@ -147,7 +153,8 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { // Then: the forwarded request carries the session ID and sequence number headers. require.Equal(t, 1, s.injectBackend.requestCount()) - header := s.injectBackend.headersAt(0) + header, err := s.injectBackend.headersAt(0) + require.NoError(t, err) assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) @@ -179,7 +186,8 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { // Then: no correlation headers are present on the forwarded request. require.Equal(t, 1, s.otherBackend.requestCount()) - header := s.otherBackend.headersAt(0) + header, err := s.otherBackend.headersAt(0) + require.NoError(t, err) assert.Empty(t, header.Get(config.SessionIDHeaderName), "non-inject-target requests must not carry session ID header") assert.Empty(t, header.Get(config.SequenceNumberHeaderName), @@ -285,19 +293,22 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { require.Equal(t, 2, inject.requestCount(), "inject-target backend should have received exactly two requests") - firstInjectHeader := inject.headersAt(0) + firstInjectHeader, err := inject.headersAt(0) + require.NoError(t, err) assert.Equal(t, sessionID, firstInjectHeader.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", firstInjectHeader.Get(config.SequenceNumberHeaderName), "first inject-target request must have sequence 0") - secondInjectHeader := inject.headersAt(1) + secondInjectHeader, err := inject.headersAt(1) + require.NoError(t, err) assert.Equal(t, sessionID, secondInjectHeader.Get(config.SessionIDHeaderName)) assert.Equal(t, "3", secondInjectHeader.Get(config.SequenceNumberHeaderName), "second inject-target request must have sequence 3") // Then: the non-inject-target backend receives no correlation headers. require.Equal(t, 1, other.requestCount()) - otherHeader := other.headersAt(0) + otherHeader, err := other.headersAt(0) + require.NoError(t, err) assert.Empty(t, otherHeader.Get(config.SessionIDHeaderName)) assert.Empty(t, otherHeader.Get(config.SequenceNumberHeaderName)) @@ -360,8 +371,12 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // Then: the inject-target headers show a gap from sequence 0 to 4. require.Equal(t, 2, inject.requestCount()) - assert.Equal(t, "0", inject.headersAt(0).Get(config.SequenceNumberHeaderName)) - assert.Equal(t, "4", inject.headersAt(1).Get(config.SequenceNumberHeaderName)) + firstHeader, err := inject.headersAt(0) + require.NoError(t, err) + assert.Equal(t, "0", firstHeader.Get(config.SequenceNumberHeaderName)) + secondHeader, err := inject.headersAt(1) + require.NoError(t, err) + assert.Equal(t, "4", secondHeader.Get(config.SequenceNumberHeaderName)) // Then: the gap of 3 matches the three intermediate tool-use requests. events := aud.getRequests() @@ -403,7 +418,8 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // Then: the backend receives the real values, not the spoofed ones. require.Equal(t, 1, s.injectBackend.requestCount()) - header := s.injectBackend.headersAt(0) + header, err := s.injectBackend.headersAt(0) + require.NoError(t, err) assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) @@ -449,7 +465,8 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { // Then: no correlation headers are injected on the forwarded request. require.Equal(t, 1, backend.requestCount()) - header := backend.headersAt(0) + header, err := backend.headersAt(0) + require.NoError(t, err) assert.Empty(t, header.Get(config.SessionIDHeaderName), "session ID header must not be injected when correlation is disabled") assert.Empty(t, header.Get(config.SequenceNumberHeaderName), @@ -512,7 +529,8 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { require.Equal(t, numRequests, s.injectBackend.requestCount()) headerSeqs := make(map[string]bool, numRequests) for i := 0; i < numRequests; i++ { - header := s.injectBackend.headersAt(i) + header, err := s.injectBackend.headersAt(i) + require.NoError(t, err) seqStr := header.Get(config.SequenceNumberHeaderName) assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) headerSeqs[seqStr] = true From 525ef137c9a2322aac761cddff1200596bfaf702 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 21 May 2026 10:03:01 +0000 Subject: [PATCH 12/15] fix(proxy): improve stop method in correlation test environment - Added nil checks for the proxy target before stopping it to prevent potential nil pointer dereference errors. - This change enhances the stability of the correlation test environment by ensuring safe resource cleanup. --- proxy/proxy_session_correlation_integration_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 70d2d19..e797657 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -74,7 +74,9 @@ type correlationTestEnv struct { } func (s *correlationTestEnv) stop() { - s.pt.Stop() + if s.pt != nil { + s.pt.Stop() + } if s.injectBackend != nil { s.injectBackend.close() } From 6952ad07382b05d7ba6c3ea89ab9bd74f49e1d11 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 21 May 2026 10:15:47 +0000 Subject: [PATCH 13/15] fix(proxy): update session ID in integration test for consistency --- proxy/proxy_session_correlation_integration_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index e797657..83888fa 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -137,7 +137,7 @@ func newCorrelationTestEnv(t *testing.T, sessionID string) *correlationTestEnv { // the forwarded header. func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { // Given: a proxy with session correlation enabled and an inject-target backend. - const sessionID = "e5f6a7b8-0000-0000-0000-000000000000" + const sessionID = "e5f6a7b8-c9d0-4e1f-8a2b-3c4d5e6f7a8b" s := newCorrelationTestEnv(t, sessionID) defer s.stop() From 0ed9c628b4f738be6b87e5a1153ddce333ac0ab0 Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 21 May 2026 12:39:53 +0000 Subject: [PATCH 14/15] fix(proxy): update integration tests to use consistent header variable naming - Renamed header variables to headers for clarity and consistency across multiple integration tests. - Added nil checks for headers to ensure they are not nil before assertions, enhancing test robustness. - These changes improve the readability and reliability of the session correlation integration tests. --- ...xy_session_correlation_integration_test.go | 68 +++++++++++-------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 83888fa..4d4fd07 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -155,15 +155,16 @@ func TestIntegration_LLMRequestAuditAndHeadersAgree(t *testing.T) { // Then: the forwarded request carries the session ID and sequence number headers. require.Equal(t, 1, s.injectBackend.requestCount()) - header, err := s.injectBackend.headersAt(0) + headers, err := s.injectBackend.headersAt(0) require.NoError(t, err) - assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) + require.NotNil(t, headers) + assert.Equal(t, sessionID, headers.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", headers.Get(config.SequenceNumberHeaderName)) // Then: the audit event and forwarded header agree on the sequence number. assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), - header.Get(config.SequenceNumberHeaderName), + headers.Get(config.SequenceNumberHeaderName), "audit event and forwarded header must carry the same sequence number", ) } @@ -188,11 +189,12 @@ func TestIntegration_NonLLMRequestAuditedWithoutHeaders(t *testing.T) { // Then: no correlation headers are present on the forwarded request. require.Equal(t, 1, s.otherBackend.requestCount()) - header, err := s.otherBackend.headersAt(0) + headers, err := s.otherBackend.headersAt(0) require.NoError(t, err) - assert.Empty(t, header.Get(config.SessionIDHeaderName), + require.NotNil(t, headers) + assert.Empty(t, headers.Get(config.SessionIDHeaderName), "non-inject-target requests must not carry session ID header") - assert.Empty(t, header.Get(config.SequenceNumberHeaderName), + assert.Empty(t, headers.Get(config.SequenceNumberHeaderName), "non-inject-target requests must not carry sequence number header") } @@ -295,24 +297,27 @@ func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { require.Equal(t, 2, inject.requestCount(), "inject-target backend should have received exactly two requests") - firstInjectHeader, err := inject.headersAt(0) + firstInjectHeaders, err := inject.headersAt(0) require.NoError(t, err) - assert.Equal(t, sessionID, firstInjectHeader.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", firstInjectHeader.Get(config.SequenceNumberHeaderName), + require.NotNil(t, firstInjectHeaders) + assert.Equal(t, sessionID, firstInjectHeaders.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", firstInjectHeaders.Get(config.SequenceNumberHeaderName), "first inject-target request must have sequence 0") - secondInjectHeader, err := inject.headersAt(1) + secondInjectHeaders, err := inject.headersAt(1) require.NoError(t, err) - assert.Equal(t, sessionID, secondInjectHeader.Get(config.SessionIDHeaderName)) - assert.Equal(t, "3", secondInjectHeader.Get(config.SequenceNumberHeaderName), + require.NotNil(t, secondInjectHeaders) + assert.Equal(t, sessionID, secondInjectHeaders.Get(config.SessionIDHeaderName)) + assert.Equal(t, "3", secondInjectHeaders.Get(config.SequenceNumberHeaderName), "second inject-target request must have sequence 3") // Then: the non-inject-target backend receives no correlation headers. require.Equal(t, 1, other.requestCount()) - otherHeader, err := other.headersAt(0) + otherHeaders, err := other.headersAt(0) require.NoError(t, err) - assert.Empty(t, otherHeader.Get(config.SessionIDHeaderName)) - assert.Empty(t, otherHeader.Get(config.SequenceNumberHeaderName)) + require.NotNil(t, otherHeaders) + assert.Empty(t, otherHeaders.Get(config.SessionIDHeaderName)) + assert.Empty(t, otherHeaders.Get(config.SequenceNumberHeaderName)) // Then: the gap between inject-target sequence numbers (0 and 3) // reveals 2 intermediate events (non-inject-target allowed and denied). @@ -373,12 +378,14 @@ func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { // Then: the inject-target headers show a gap from sequence 0 to 4. require.Equal(t, 2, inject.requestCount()) - firstHeader, err := inject.headersAt(0) + firstHeaders, err := inject.headersAt(0) require.NoError(t, err) - assert.Equal(t, "0", firstHeader.Get(config.SequenceNumberHeaderName)) - secondHeader, err := inject.headersAt(1) + require.NotNil(t, firstHeaders) + assert.Equal(t, "0", firstHeaders.Get(config.SequenceNumberHeaderName)) + secondHeaders, err := inject.headersAt(1) require.NoError(t, err) - assert.Equal(t, "4", secondHeader.Get(config.SequenceNumberHeaderName)) + require.NotNil(t, secondHeaders) + assert.Equal(t, "4", secondHeaders.Get(config.SequenceNumberHeaderName)) // Then: the gap of 3 matches the three intermediate tool-use requests. events := aud.getRequests() @@ -420,17 +427,18 @@ func TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence(t *testing.T) // Then: the backend receives the real values, not the spoofed ones. require.Equal(t, 1, s.injectBackend.requestCount()) - header, err := s.injectBackend.headersAt(0) + headers, err := s.injectBackend.headersAt(0) require.NoError(t, err) - assert.Equal(t, sessionID, header.Get(config.SessionIDHeaderName)) - assert.Equal(t, "0", header.Get(config.SequenceNumberHeaderName)) + require.NotNil(t, headers) + assert.Equal(t, sessionID, headers.Get(config.SessionIDHeaderName)) + assert.Equal(t, "0", headers.Get(config.SequenceNumberHeaderName)) // Then: the audit event agrees with the forwarded header. events := s.auditor.getRequests() require.Len(t, events, 1) assert.Equal(t, strconv.Itoa(int(events[0].SequenceNumber)), - header.Get(config.SequenceNumberHeaderName), + headers.Get(config.SequenceNumberHeaderName), ) } @@ -467,11 +475,12 @@ func TestIntegration_DisabledCorrelationNoHeaders(t *testing.T) { // Then: no correlation headers are injected on the forwarded request. require.Equal(t, 1, backend.requestCount()) - header, err := backend.headersAt(0) + headers, err := backend.headersAt(0) require.NoError(t, err) - assert.Empty(t, header.Get(config.SessionIDHeaderName), + require.NotNil(t, headers) + assert.Empty(t, headers.Get(config.SessionIDHeaderName), "session ID header must not be injected when correlation is disabled") - assert.Empty(t, header.Get(config.SequenceNumberHeaderName), + assert.Empty(t, headers.Get(config.SequenceNumberHeaderName), "sequence number header must not be injected when correlation is disabled") // Then: the request is still audited. @@ -531,9 +540,10 @@ func TestIntegration_ConcurrentRequestsUniqueSequenceNumbers(t *testing.T) { require.Equal(t, numRequests, s.injectBackend.requestCount()) headerSeqs := make(map[string]bool, numRequests) for i := 0; i < numRequests; i++ { - header, err := s.injectBackend.headersAt(i) + headers, err := s.injectBackend.headersAt(i) require.NoError(t, err) - seqStr := header.Get(config.SequenceNumberHeaderName) + require.NotNil(t, headers) + seqStr := headers.Get(config.SequenceNumberHeaderName) assert.NotEmpty(t, seqStr, "request %d: sequence header must be set", i) headerSeqs[seqStr] = true } From 845c5ce177b400a3f044367264cbbd89cc362c2f Mon Sep 17 00:00:00 2001 From: Sas Swart Date: Thu, 21 May 2026 12:55:52 +0000 Subject: [PATCH 15/15] refactor(proxy): rename and enhance sequence gap integration test --- ...xy_session_correlation_integration_test.go | 193 +++++------------- 1 file changed, 52 insertions(+), 141 deletions(-) diff --git a/proxy/proxy_session_correlation_integration_test.go b/proxy/proxy_session_correlation_integration_test.go index 4d4fd07..95bbc23 100644 --- a/proxy/proxy_session_correlation_integration_test.go +++ b/proxy/proxy_session_correlation_integration_test.go @@ -233,175 +233,86 @@ func TestIntegration_DeniedRequestAuditedNeverForwarded(t *testing.T) { "denied requests must not be forwarded to the backend") } -// TestIntegration_MixedRequestsSequenceOrdering sends a realistic -// sequence of LLM, non-LLM, and denied requests, then verifies: -// 1. Sequence numbers increase monotonically across all request types. +// TestIntegration_SequenceGapAcrossMixedRequests sends two inject-target +// requests bookending three allowed tool-use requests and one denied +// request, then verifies: +// 1. Sequence numbers increase monotonically across all six events. // 2. Only inject-target requests carry correlation headers. -// 3. The sequence numbers in headers match the audit events. -// 4. The gap between two LLM requests' sequence numbers reveals the -// intermediate non-LLM and denied activity. -func TestIntegration_MixedRequestsSequenceOrdering(t *testing.T) { - const sessionID = "mixed-test-session" +// 3. The session ID and sequence number in headers match the audit events. +// 4. The gap of 4 between the two inject-target sequence numbers (0 and 5) +// precisely accounts for the three allowed tool-use requests and the +// one denied request in between. +func TestIntegration_SequenceGapAcrossMixedRequests(t *testing.T) { + const sessionID = "mixed-session" - // Given: a proxy with an inject-target backend and a non-inject-target backend. - inject := newMultiRequestCapturingBackend() - defer inject.close() - - other := newMultiRequestCapturingBackend() - defer other.close() - - injectURL, err := url.Parse(inject.server.URL) - require.NoError(t, err) + // Given: a proxy with an inject-target and a non-inject-target backend. + s := newCorrelationTestEnv(t, sessionID) + defer s.stop() - otherURL, err := url.Parse(other.server.URL) - require.NoError(t, err) + // When: a request is sent to the inject-target backend. + s.pt.ExpectGetViaProxy(s.injectBackend.server.URL+"/v1/messages", http.StatusOK) - aud := &capturingAuditor{} + for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { + // When: a request is sent to the non-inject-target backend. + s.pt.ExpectGetViaProxy(s.otherBackend.server.URL+p, http.StatusOK) + } - pt := NewProxyTest(t, - WithCertManager(t.TempDir()), - WithAllowedDomain(injectURL.Hostname()), - WithAllowedDomain(otherURL.Hostname()), - WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, - }), - WithSessionID(sessionID), - WithAuditor(aud), - ).Start() - defer pt.Stop() + // When: a request is sent to a domain that is not allowed. + s.pt.ExpectGetViaProxy("http://evil.example.com/exfil", http.StatusForbidden) - // When: an inject-target, non-inject-target, denied, and inject-target - // request are sent in sequence. - pt.ExpectGetViaProxy(inject.server.URL+"/v1/messages", http.StatusOK) - pt.ExpectGetViaProxy(other.server.URL+"/coder/coder", http.StatusOK) - pt.ExpectGetViaProxy("http://evil.example.com/exfil", http.StatusForbidden) - pt.ExpectGetViaProxy(inject.server.URL+"/v1/messages", http.StatusOK) + // When: a request is sent to the inject-target backend. + s.pt.ExpectGetViaProxy(s.injectBackend.server.URL+"/v1/messages", http.StatusOK) - // Then: all four requests produce audit events with monotonically - // increasing sequence numbers. - events := aud.getRequests() - require.Len(t, events, 4, "expected exactly four audit events") + // Then: all six events are audited with monotonically increasing + // sequence numbers and correct allowed flags. + events := s.auditor.getRequests() + require.Len(t, events, 6, "expected exactly six audit events") - expectedSeq := []int32{0, 1, 2, 3} - expectedAllowed := []bool{true, true, false, true} + expectedAllowed := []bool{true, true, true, true, false, true} for i, ev := range events { - assert.Equal(t, expectedSeq[i], ev.SequenceNumber, + assert.Equal(t, int32(i), ev.SequenceNumber, "event %d: wrong sequence number", i) assert.Equal(t, expectedAllowed[i], ev.Allowed, "event %d: wrong allowed flag", i) } // Then: the inject-target backend receives correlation headers with - // the correct sequence numbers. - require.Equal(t, 2, inject.requestCount(), + // the correct session ID and sequence numbers. + require.Equal(t, 2, s.injectBackend.requestCount(), "inject-target backend should have received exactly two requests") - firstInjectHeaders, err := inject.headersAt(0) + firstInjectHeaders, err := s.injectBackend.headersAt(0) require.NoError(t, err) require.NotNil(t, firstInjectHeaders) assert.Equal(t, sessionID, firstInjectHeaders.Get(config.SessionIDHeaderName)) assert.Equal(t, "0", firstInjectHeaders.Get(config.SequenceNumberHeaderName), "first inject-target request must have sequence 0") - secondInjectHeaders, err := inject.headersAt(1) + secondInjectHeaders, err := s.injectBackend.headersAt(1) require.NoError(t, err) require.NotNil(t, secondInjectHeaders) assert.Equal(t, sessionID, secondInjectHeaders.Get(config.SessionIDHeaderName)) - assert.Equal(t, "3", secondInjectHeaders.Get(config.SequenceNumberHeaderName), - "second inject-target request must have sequence 3") - - // Then: the non-inject-target backend receives no correlation headers. - require.Equal(t, 1, other.requestCount()) - otherHeaders, err := other.headersAt(0) - require.NoError(t, err) - require.NotNil(t, otherHeaders) - assert.Empty(t, otherHeaders.Get(config.SessionIDHeaderName)) - assert.Empty(t, otherHeaders.Get(config.SequenceNumberHeaderName)) - - // Then: the gap between inject-target sequence numbers (0 and 3) - // reveals 2 intermediate events (non-inject-target allowed and denied). - firstInjectSeq := events[0].SequenceNumber - secondInjectSeq := events[3].SequenceNumber - gap := secondInjectSeq - firstInjectSeq - 1 - assert.Equal(t, int32(2), gap, - "gap between inject-target requests should reveal 2 intermediate events") -} - -// TestIntegration_SequenceGapRevealsAgenticLoop sends two inject-target -// requests with several non-inject-target requests in between, -// simulating an agentic loop where the model triggers tool-use HTTP -// calls between prompts. The test verifies that the gap in -// inject-target sequence numbers precisely reflects the count of -// intermediate boundary events. -func TestIntegration_SequenceGapRevealsAgenticLoop(t *testing.T) { - const sessionID = "agentic-loop-session" - - // Given: a proxy with an inject-target and a non-inject-target backend. - inject := newMultiRequestCapturingBackend() - defer inject.close() - - other := newMultiRequestCapturingBackend() - defer other.close() - - injectURL, err := url.Parse(inject.server.URL) - require.NoError(t, err) - - otherURL, err := url.Parse(other.server.URL) - require.NoError(t, err) - - aud := &capturingAuditor{} - - pt := NewProxyTest(t, - WithCertManager(t.TempDir()), - WithAllowedDomain(injectURL.Hostname()), - WithAllowedDomain(otherURL.Hostname()), - WithSessionCorrelation(config.SessionCorrelationConfig{ - Enabled: true, - InjectTargets: []string{"domain=" + injectURL.Hostname() + " path=/v1/*"}, - }), - WithSessionID(sessionID), - WithAuditor(aud), - ).Start() - defer pt.Stop() - - // When: an inject-target request, three tool-use requests to the - // non-inject-target backend, and another inject-target request are - // sent in sequence. - pt.ExpectGetViaProxy(inject.server.URL+"/v1/messages", http.StatusOK) - - for _, p := range []string{"/coder/coder", "/coder/coder/issues", "/coder/coder/pulls"} { - pt.ExpectGetViaProxy(other.server.URL+p, http.StatusOK) + assert.Equal(t, "5", secondInjectHeaders.Get(config.SequenceNumberHeaderName), + "second inject-target request must have sequence 5") + + // Then: the non-inject-target backend receives no correlation headers + // on any of its three requests. + require.Equal(t, 3, s.otherBackend.requestCount()) + for i := 0; i < 3; i++ { + h, err := s.otherBackend.headersAt(i) + require.NoError(t, err) + require.NotNil(t, h) + assert.Empty(t, h.Get(config.SessionIDHeaderName), + "other backend request %d must not carry session ID header", i) + assert.Empty(t, h.Get(config.SequenceNumberHeaderName), + "other backend request %d must not carry sequence number header", i) } - pt.ExpectGetViaProxy(inject.server.URL+"/v1/messages", http.StatusOK) - - // Then: the inject-target headers show a gap from sequence 0 to 4. - require.Equal(t, 2, inject.requestCount()) - firstHeaders, err := inject.headersAt(0) - require.NoError(t, err) - require.NotNil(t, firstHeaders) - assert.Equal(t, "0", firstHeaders.Get(config.SequenceNumberHeaderName)) - secondHeaders, err := inject.headersAt(1) - require.NoError(t, err) - require.NotNil(t, secondHeaders) - assert.Equal(t, "4", secondHeaders.Get(config.SequenceNumberHeaderName)) - - // Then: the gap of 3 matches the three intermediate tool-use requests. - events := aud.getRequests() - require.Len(t, events, 5) - - firstInjectSeq := events[0].SequenceNumber - secondInjectSeq := events[4].SequenceNumber - gap := secondInjectSeq - firstInjectSeq - 1 - assert.Equal(t, int32(3), gap, - "gap between prompts should equal number of tool-use requests") - - // Then: the intermediate audit events correspond to the tool-use requests. - for i := 1; i <= 3; i++ { - assert.Equal(t, int32(i), events[i].SequenceNumber) - assert.True(t, events[i].Allowed) - } + // Then: the gap of 4 between inject-target sequence numbers (0 and 5) + // accounts for the three tool-use requests and the one denied request. + gap := events[5].SequenceNumber - events[0].SequenceNumber - 1 + assert.Equal(t, int32(4), gap, + "gap between inject-target requests should reveal 4 intermediate events") } // TestIntegration_SpoofedHeadersOverwrittenWithCorrectSequence