Skip to content

Commit 373cce1

Browse files
committed
feat(mcp): add input sanitization and test coverage
MCP-SAN (#49): Input sanitization for the MCP server layer. - Add sanitize package: Content (Markdown structure injection), Reflect (truncate + strip control chars for error messages), SessionID (path-safe session identifiers), StripControl, Truncate - Sanitize all reflected user inputs in dispatch error messages (tool names, prompt names, resource URIs) via sanitize.Reflect - Reject unknown entry types before writing to .context/ files - Enforce MaxContentLen (32KB) on entry content in extract.EntryArgs - Sanitize entry content and optional fields via sanitize.Content and extract.SanitizedOpts before writing - Cap journal source limit to MaxSourceLimit (100) - Sanitize caller identifiers in session events - Add input length constants to config/mcp/cfg - Add error message keys for input-too-long and unknown-entry-type MCP-COV (#50): Comprehensive test coverage for MCP subsystem. - internal/mcp/proto: 22 schema round-trip and edge-case tests - internal/mcp/session: 7 state lifecycle tests (100% coverage) - internal/mcp/server: 4 integration tests (Serve edge cases, prompt add-learning) - internal/mcp/server/def/tool: 9 tool definition tests - internal/mcp/server/def/prompt: 9 prompt definition tests - internal/mcp/server/extract: 7 extraction and sanitization tests - internal/mcp/server/io: 3 WriteJSON tests (100% coverage) - internal/mcp/server/out: 8 response builder tests (100% coverage) - internal/mcp/server/parse: 3 request parsing tests (100% coverage) - internal/mcp/server/stat: 2 statistics tests (100% coverage) - internal/sanitize: 22 sanitization tests (Content, Reflect, SessionID, StripControl, Truncate + existing Filename) - Server package coverage: 73% -> 92% Closes #49 Closes #50 Signed-off-by: CoderMungan <codermungan@gmail.com>
1 parent 373f4eb commit 373cce1

23 files changed

Lines changed: 1690 additions & 7 deletions

File tree

internal/assets/commands/text/mcp.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,10 @@ mcp.err-unknown-prompt:
346346
short: 'unknown prompt: %s'
347347
mcp.err-uri-required:
348348
short: uri is required
349+
mcp.err-input-too-long:
350+
short: '%s exceeds maximum length (%d bytes)'
351+
mcp.err-unknown-entry-type:
352+
short: 'unknown entry type: %s'
349353
mcp.format-watch-completed:
350354
short: 'Completed: %s'
351355
mcp.format-wrote:

internal/config/embed/text/mcp_err.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,6 @@ const (
1818
DescKeyMCPErrQueryRequired = "mcp.err-query-required"
1919
DescKeyMCPErrUnknownPrompt = "mcp.err-unknown-prompt"
2020
DescKeyMCPErrURIRequired = "mcp.err-uri-required"
21+
DescKeyMCPErrInputTooLong = "mcp.err-input-too-long"
22+
DescKeyMCPErrUnknownEntryType = "mcp.err-unknown-entry-type"
2123
)

internal/config/mcp/cfg/config.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,23 @@ const (
1212

1313
// DefaultSourceLimit is the max sessions returned by ctx_journal_source.
1414
DefaultSourceLimit = 5
15+
// MaxSourceLimit caps the source limit to prevent unbounded queries.
16+
MaxSourceLimit = 100
1517
// MinWordLen is the shortest word considered for overlap matching.
1618
MinWordLen = 4
1719
// MinWordOverlap is the minimum word matches to signal task completion.
1820
MinWordOverlap = 2
21+
22+
// --- Input length limits (MCP-SAN.1) ---
23+
24+
// MaxContentLen is the maximum byte length for entry content fields.
25+
MaxContentLen = 32_000
26+
// MaxNameLen is the maximum byte length for tool/prompt/resource names.
27+
MaxNameLen = 256
28+
// MaxQueryLen is the maximum byte length for search queries.
29+
MaxQueryLen = 1_000
30+
// MaxCallerLen is the maximum byte length for caller identifiers.
31+
MaxCallerLen = 128
32+
// MaxURILen is the maximum byte length for resource URIs.
33+
MaxURILen = 512
1934
)

internal/mcp/proto/schema_test.go

Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
// / ctx: https://ctx.ist
2+
// ,'`./ do you remember?
3+
// `.,'\
4+
// \ Copyright 2026-present Context contributors.
5+
// SPDX-License-Identifier: Apache-2.0
6+
7+
package proto
8+
9+
import (
10+
"encoding/json"
11+
"testing"
12+
)
13+
14+
func roundTrip(t *testing.T, v interface{}, dst interface{}) {
15+
t.Helper()
16+
data, err := json.Marshal(v)
17+
if err != nil {
18+
t.Fatalf("marshal: %v", err)
19+
}
20+
if err := json.Unmarshal(data, dst); err != nil {
21+
t.Fatalf("unmarshal: %v", err)
22+
}
23+
}
24+
25+
func TestRequestRoundTrip(t *testing.T) {
26+
orig := Request{
27+
JSONRPC: "2.0",
28+
ID: json.RawMessage(`1`),
29+
Method: "tools/call",
30+
Params: json.RawMessage(`{"name":"ctx_status"}`),
31+
}
32+
var got Request
33+
roundTrip(t, orig, &got)
34+
if got.JSONRPC != orig.JSONRPC {
35+
t.Errorf("JSONRPC = %q, want %q", got.JSONRPC, orig.JSONRPC)
36+
}
37+
if got.Method != orig.Method {
38+
t.Errorf("Method = %q, want %q", got.Method, orig.Method)
39+
}
40+
if string(got.ID) != string(orig.ID) {
41+
t.Errorf("ID = %s, want %s", got.ID, orig.ID)
42+
}
43+
}
44+
45+
func TestResponseSuccessRoundTrip(t *testing.T) {
46+
orig := Response{
47+
JSONRPC: "2.0",
48+
ID: json.RawMessage(`1`),
49+
Result: map[string]string{"key": "value"},
50+
}
51+
var got Response
52+
roundTrip(t, orig, &got)
53+
if got.JSONRPC != "2.0" {
54+
t.Errorf("JSONRPC = %q, want %q", got.JSONRPC, "2.0")
55+
}
56+
if got.Error != nil {
57+
t.Errorf("unexpected error: %v", got.Error)
58+
}
59+
}
60+
61+
func TestResponseErrorRoundTrip(t *testing.T) {
62+
orig := Response{
63+
JSONRPC: "2.0",
64+
ID: json.RawMessage(`1`),
65+
Error: &RPCError{
66+
Code: ErrCodeNotFound,
67+
Message: "method not found",
68+
},
69+
}
70+
var got Response
71+
roundTrip(t, orig, &got)
72+
if got.Error == nil {
73+
t.Fatal("expected error in response")
74+
}
75+
if got.Error.Code != ErrCodeNotFound {
76+
t.Errorf("Code = %d, want %d", got.Error.Code, ErrCodeNotFound)
77+
}
78+
if got.Error.Message != "method not found" {
79+
t.Errorf("Message = %q, want %q", got.Error.Message, "method not found")
80+
}
81+
}
82+
83+
func TestNotificationRoundTrip(t *testing.T) {
84+
orig := Notification{
85+
JSONRPC: "2.0",
86+
Method: "notifications/initialized",
87+
}
88+
var got Notification
89+
roundTrip(t, orig, &got)
90+
if got.Method != "notifications/initialized" {
91+
t.Errorf("Method = %q, want %q", got.Method, "notifications/initialized")
92+
}
93+
}
94+
95+
func TestRPCErrorWithData(t *testing.T) {
96+
orig := RPCError{
97+
Code: ErrCodeInvalidArg,
98+
Message: "invalid",
99+
Data: map[string]string{"field": "name"},
100+
}
101+
var got RPCError
102+
roundTrip(t, orig, &got)
103+
if got.Code != ErrCodeInvalidArg {
104+
t.Errorf("Code = %d, want %d", got.Code, ErrCodeInvalidArg)
105+
}
106+
}
107+
108+
func TestInitializeParamsRoundTrip(t *testing.T) {
109+
orig := InitializeParams{
110+
ProtocolVersion: ProtocolVersion,
111+
ClientInfo: AppInfo{Name: "test-client", Version: "1.0.0"},
112+
}
113+
var got InitializeParams
114+
roundTrip(t, orig, &got)
115+
if got.ProtocolVersion != ProtocolVersion {
116+
t.Errorf("ProtocolVersion = %q, want %q", got.ProtocolVersion, ProtocolVersion)
117+
}
118+
if got.ClientInfo.Name != "test-client" {
119+
t.Errorf("ClientInfo.Name = %q, want %q", got.ClientInfo.Name, "test-client")
120+
}
121+
}
122+
123+
func TestInitializeResultRoundTrip(t *testing.T) {
124+
orig := InitializeResult{
125+
ProtocolVersion: ProtocolVersion,
126+
Capabilities: ServerCaps{
127+
Resources: &ResourcesCap{Subscribe: true, ListChanged: true},
128+
Tools: &ToolsCap{ListChanged: true},
129+
Prompts: &PromptsCap{ListChanged: false},
130+
},
131+
ServerInfo: AppInfo{Name: "ctx", Version: "0.3.0"},
132+
}
133+
var got InitializeResult
134+
roundTrip(t, orig, &got)
135+
if got.Capabilities.Resources == nil {
136+
t.Fatal("expected Resources capability")
137+
}
138+
if !got.Capabilities.Resources.Subscribe {
139+
t.Error("expected Subscribe = true")
140+
}
141+
}
142+
143+
func TestResourceRoundTrip(t *testing.T) {
144+
orig := Resource{
145+
URI: "ctx://context/tasks",
146+
Name: "tasks",
147+
MimeType: "text/markdown",
148+
}
149+
var got Resource
150+
roundTrip(t, orig, &got)
151+
if got.URI != orig.URI {
152+
t.Errorf("URI = %q, want %q", got.URI, orig.URI)
153+
}
154+
}
155+
156+
func TestToolRoundTrip(t *testing.T) {
157+
orig := Tool{
158+
Name: "ctx_status",
159+
InputSchema: InputSchema{
160+
Type: "object",
161+
Properties: map[string]Property{
162+
"verbose": {Type: "boolean", Description: "Verbose"},
163+
},
164+
Required: []string{"verbose"},
165+
},
166+
Annotations: &ToolAnnotations{ReadOnlyHint: true},
167+
}
168+
var got Tool
169+
roundTrip(t, orig, &got)
170+
if got.Name != "ctx_status" {
171+
t.Errorf("Name = %q, want %q", got.Name, "ctx_status")
172+
}
173+
if got.Annotations == nil || !got.Annotations.ReadOnlyHint {
174+
t.Error("expected ReadOnlyHint = true")
175+
}
176+
}
177+
178+
func TestCallToolParamsRoundTrip(t *testing.T) {
179+
orig := CallToolParams{
180+
Name: "ctx_add",
181+
Arguments: map[string]interface{}{"type": "task", "content": "Test"},
182+
}
183+
var got CallToolParams
184+
roundTrip(t, orig, &got)
185+
if got.Name != "ctx_add" {
186+
t.Errorf("Name = %q, want %q", got.Name, "ctx_add")
187+
}
188+
}
189+
190+
func TestCallToolResultRoundTrip(t *testing.T) {
191+
orig := CallToolResult{
192+
Content: []ToolContent{{Type: "text", Text: "Done"}},
193+
}
194+
var got CallToolResult
195+
roundTrip(t, orig, &got)
196+
if len(got.Content) != 1 {
197+
t.Fatalf("Content count = %d, want 1", len(got.Content))
198+
}
199+
if got.Content[0].Text != "Done" {
200+
t.Errorf("Text = %q, want %q", got.Content[0].Text, "Done")
201+
}
202+
if got.IsError {
203+
t.Error("expected IsError = false")
204+
}
205+
}
206+
207+
func TestCallToolResultErrorRoundTrip(t *testing.T) {
208+
orig := CallToolResult{
209+
Content: []ToolContent{{Type: "text", Text: "failed"}},
210+
IsError: true,
211+
}
212+
var got CallToolResult
213+
roundTrip(t, orig, &got)
214+
if !got.IsError {
215+
t.Error("expected IsError = true")
216+
}
217+
}
218+
219+
func TestPromptRoundTrip(t *testing.T) {
220+
orig := Prompt{
221+
Name: "ctx-session-start",
222+
Arguments: []PromptArgument{
223+
{Name: "content", Required: true},
224+
},
225+
}
226+
var got Prompt
227+
roundTrip(t, orig, &got)
228+
if got.Name != "ctx-session-start" {
229+
t.Errorf("Name = %q, want %q", got.Name, "ctx-session-start")
230+
}
231+
if len(got.Arguments) != 1 || !got.Arguments[0].Required {
232+
t.Error("expected 1 required argument")
233+
}
234+
}
235+
236+
func TestGetPromptResultRoundTrip(t *testing.T) {
237+
orig := GetPromptResult{
238+
Description: "Test",
239+
Messages: []PromptMessage{
240+
{Role: "user", Content: ToolContent{Type: "text", Text: "Hi"}},
241+
},
242+
}
243+
var got GetPromptResult
244+
roundTrip(t, orig, &got)
245+
if len(got.Messages) != 1 {
246+
t.Fatalf("Messages count = %d, want 1", len(got.Messages))
247+
}
248+
if got.Messages[0].Role != "user" {
249+
t.Errorf("Role = %q, want %q", got.Messages[0].Role, "user")
250+
}
251+
}
252+
253+
func TestSubscribeParamsRoundTrip(t *testing.T) {
254+
orig := SubscribeParams{URI: "ctx://context/tasks"}
255+
var got SubscribeParams
256+
roundTrip(t, orig, &got)
257+
if got.URI != orig.URI {
258+
t.Errorf("URI = %q, want %q", got.URI, orig.URI)
259+
}
260+
}
261+
262+
func TestUnsubscribeParamsRoundTrip(t *testing.T) {
263+
orig := UnsubscribeParams{URI: "ctx://context/decisions"}
264+
var got UnsubscribeParams
265+
roundTrip(t, orig, &got)
266+
if got.URI != orig.URI {
267+
t.Errorf("URI = %q, want %q", got.URI, orig.URI)
268+
}
269+
}
270+
271+
func TestResourceUpdatedParamsRoundTrip(t *testing.T) {
272+
orig := ResourceUpdatedParams{URI: "ctx://context/tasks"}
273+
var got ResourceUpdatedParams
274+
roundTrip(t, orig, &got)
275+
if got.URI != orig.URI {
276+
t.Errorf("URI = %q, want %q", got.URI, orig.URI)
277+
}
278+
}
279+
280+
func TestErrorCodeConstants(t *testing.T) {
281+
if ErrCodeParse != -32700 {
282+
t.Errorf("ErrCodeParse = %d, want -32700", ErrCodeParse)
283+
}
284+
if ErrCodeNotFound != -32601 {
285+
t.Errorf("ErrCodeNotFound = %d, want -32601", ErrCodeNotFound)
286+
}
287+
if ErrCodeInvalidArg != -32602 {
288+
t.Errorf("ErrCodeInvalidArg = %d, want -32602", ErrCodeInvalidArg)
289+
}
290+
if ErrCodeInternal != -32603 {
291+
t.Errorf("ErrCodeInternal = %d, want -32603", ErrCodeInternal)
292+
}
293+
}
294+
295+
func TestProtocolVersionValue(t *testing.T) {
296+
if ProtocolVersion != "2024-11-05" {
297+
t.Errorf("ProtocolVersion = %q, want %q", ProtocolVersion, "2024-11-05")
298+
}
299+
}
300+
301+
func TestRequestNilParams(t *testing.T) {
302+
orig := Request{
303+
JSONRPC: "2.0",
304+
ID: json.RawMessage(`"abc"`),
305+
Method: "ping",
306+
}
307+
data, err := json.Marshal(orig)
308+
if err != nil {
309+
t.Fatalf("marshal: %v", err)
310+
}
311+
var got Request
312+
if err := json.Unmarshal(data, &got); err != nil {
313+
t.Fatalf("unmarshal: %v", err)
314+
}
315+
if got.Params != nil {
316+
t.Errorf("expected nil Params, got %s", got.Params)
317+
}
318+
}
319+
320+
func TestResponseNilID(t *testing.T) {
321+
orig := Response{
322+
JSONRPC: "2.0",
323+
Error: &RPCError{Code: ErrCodeParse, Message: "parse error"},
324+
}
325+
var got Response
326+
roundTrip(t, orig, &got)
327+
if got.ID != nil {
328+
t.Errorf("expected nil ID, got %s", got.ID)
329+
}
330+
}
331+
332+
func TestPropertyEnumRoundTrip(t *testing.T) {
333+
orig := Property{
334+
Type: "string",
335+
Enum: []string{"task", "decision", "learning"},
336+
}
337+
var got Property
338+
roundTrip(t, orig, &got)
339+
if len(got.Enum) != 3 {
340+
t.Fatalf("Enum count = %d, want 3", len(got.Enum))
341+
}
342+
if got.Enum[0] != "task" {
343+
t.Errorf("Enum[0] = %q, want %q", got.Enum[0], "task")
344+
}
345+
}

0 commit comments

Comments
 (0)