diff --git a/internal/assets/commands/text/mcp.yaml b/internal/assets/commands/text/mcp.yaml index 7df4e460b..2e9341966 100644 --- a/internal/assets/commands/text/mcp.yaml +++ b/internal/assets/commands/text/mcp.yaml @@ -350,6 +350,10 @@ mcp.err-unknown-prompt: short: 'unknown prompt: %s' mcp.err-uri-required: short: uri is required +mcp.err-input-too-long: + short: '%s exceeds maximum length (%d bytes)' +mcp.err-unknown-entry-type: + short: 'unknown entry type: %s' mcp.format-watch-completed: short: 'Completed: %s' mcp.format-wrote: diff --git a/internal/config/embed/text/mcp_err.go b/internal/config/embed/text/mcp_err.go index 7537dc740..082d371d7 100644 --- a/internal/config/embed/text/mcp_err.go +++ b/internal/config/embed/text/mcp_err.go @@ -30,14 +30,22 @@ const ( // DescKeyMCPErrTypeContentRequired is the text key for mcp err type content // required messages. DescKeyMCPErrTypeContentRequired = "mcp.err-type-content-required" - // DescKeyMCPErrQueryRequired is the text key for mcp err query required - // messages. + // DescKeyMCPErrQueryRequired is the text key for mcp err + // query required messages. DescKeyMCPErrQueryRequired = "mcp.err-query-required" - // DescKeyMCPErrSearchRead is the text key for mcp err search read messages. + // DescKeyMCPErrSearchRead is the text key for mcp err + // search read messages. DescKeyMCPErrSearchRead = "mcp.err-search-read" - // DescKeyMCPErrUnknownPrompt is the text key for mcp err unknown prompt - // messages. + // DescKeyMCPErrUnknownPrompt is the text key for mcp err + // unknown prompt messages. DescKeyMCPErrUnknownPrompt = "mcp.err-unknown-prompt" - // DescKeyMCPErrURIRequired is the text key for mcp err uri required messages. + // DescKeyMCPErrURIRequired is the text key for mcp err + // uri required messages. DescKeyMCPErrURIRequired = "mcp.err-uri-required" + // DescKeyMCPErrInputTooLong is the text key for mcp err + // input too long messages. + DescKeyMCPErrInputTooLong = "mcp.err-input-too-long" + // DescKeyMCPErrUnknownEntryType is the text key for mcp + // err unknown entry type messages. + DescKeyMCPErrUnknownEntryType = "mcp.err-unknown-entry-type" ) diff --git a/internal/config/mcp/cfg/config.go b/internal/config/mcp/cfg/config.go index f8a7ea932..33a0d1774 100644 --- a/internal/config/mcp/cfg/config.go +++ b/internal/config/mcp/cfg/config.go @@ -13,8 +13,23 @@ const ( // DefaultSourceLimit is the max sessions returned by ctx_journal_source. DefaultSourceLimit = 5 + // MaxSourceLimit caps the source limit to prevent unbounded queries. + MaxSourceLimit = 100 // MinWordLen is the shortest word considered for overlap matching. MinWordLen = 4 // MinWordOverlap is the minimum word matches to signal task completion. MinWordOverlap = 2 + + // --- Input length limits (MCP-SAN.1) --- + + // MaxContentLen is the maximum byte length for entry content fields. + MaxContentLen = 32_000 + // MaxNameLen is the maximum byte length for tool/prompt/resource names. + MaxNameLen = 256 + // MaxQueryLen is the maximum byte length for search queries. + MaxQueryLen = 1_000 + // MaxCallerLen is the maximum byte length for caller identifiers. + MaxCallerLen = 128 + // MaxURILen is the maximum byte length for resource URIs. + MaxURILen = 512 ) diff --git a/internal/config/regex/sanitize.go b/internal/config/regex/sanitize.go new file mode 100644 index 000000000..6b818e904 --- /dev/null +++ b/internal/config/regex/sanitize.go @@ -0,0 +1,33 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package regex + +import "regexp" + +// SanEntryHeader matches entry headers like "## [2026-" in +// content sanitization (MCP-SAN.3). +var SanEntryHeader = regexp.MustCompile( + `(?m)^##\s+\[\d{4}-`, +) + +// SanTaskCheckbox matches task checkboxes "- [ ]" and +// "- [x]" in content sanitization. +var SanTaskCheckbox = regexp.MustCompile( + `(?m)^-\s+\[[x ]\]`, +) + +// SanConstitutionRule matches constitution rule format +// "- [ ] **Never" in content sanitization. +var SanConstitutionRule = regexp.MustCompile( + `(?m)^-\s+\[[x ]\]\s+\*\*[A-Z]`, +) + +// SanSessionIDUnsafe matches characters not safe for session +// IDs in file paths: anything outside [a-zA-Z0-9._-]. +var SanSessionIDUnsafe = regexp.MustCompile( + `[^a-zA-Z0-9._-]`, +) diff --git a/internal/config/sanitize/doc.go b/internal/config/sanitize/doc.go new file mode 100644 index 000000000..0fe865a28 --- /dev/null +++ b/internal/config/sanitize/doc.go @@ -0,0 +1,13 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +// Package sanitize defines string and length constants used by +// the sanitize layer. +// +// Constants are referenced by internal/sanitize via config/sanitize.*. +// Provides: [NullByte], [DotDot], [ForwardSlash], [Backslash], +// [HyphenReplace], [EscapePrefix], [MaxSessionIDLen]. +package sanitize diff --git a/internal/config/sanitize/sanitize.go b/internal/config/sanitize/sanitize.go new file mode 100644 index 000000000..0bd7fd0ff --- /dev/null +++ b/internal/config/sanitize/sanitize.go @@ -0,0 +1,34 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package sanitize + +// Sanitize-layer string and length constants. +const ( + // NullByte is the null character stripped from untrusted input. + NullByte = "\x00" + + // DotDot is a path traversal sequence. + DotDot = ".." + + // ForwardSlash is the forward slash stripped from session IDs. + ForwardSlash = "/" + + // Backslash is the backslash stripped from session IDs. + Backslash = "\\" + + // HyphenReplace is the replacement character for unsafe + // session ID characters. + HyphenReplace = "-" + + // EscapePrefix is the backslash prefix for escaping Markdown + // structural patterns. + EscapePrefix = `\` + + // MaxSessionIDLen is the maximum byte length for a session + // identifier. + MaxSessionIDLen = 128 +) diff --git a/internal/entity/mcp_session_test.go b/internal/entity/mcp_session_test.go new file mode 100644 index 000000000..7deaa8868 --- /dev/null +++ b/internal/entity/mcp_session_test.go @@ -0,0 +1,98 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package entity + +import ( + "testing" + "time" +) + +func TestNewMCPSession(t *testing.T) { + s := NewMCPSession() + if s.ToolCalls != 0 { + t.Errorf("ToolCalls = %d, want 0", s.ToolCalls) + } + if s.AddsPerformed == nil { + t.Fatal("AddsPerformed should be initialized") + } + if len(s.AddsPerformed) != 0 { + t.Errorf( + "AddsPerformed length = %d, want 0", + len(s.AddsPerformed), + ) + } + if s.SessionStartedAt.IsZero() { + t.Error("SessionStartedAt should be set") + } + if len(s.PendingFlush) != 0 { + t.Errorf( + "PendingFlush length = %d, want 0", + len(s.PendingFlush), + ) + } +} + +func TestRecordToolCall(t *testing.T) { + s := NewMCPSession() + s.RecordToolCall() + if s.ToolCalls != 1 { + t.Errorf("ToolCalls = %d, want 1", s.ToolCalls) + } + s.RecordToolCall() + s.RecordToolCall() + if s.ToolCalls != 3 { + t.Errorf("ToolCalls = %d, want 3", s.ToolCalls) + } +} + +func TestRecordAdd(t *testing.T) { + s := NewMCPSession() + s.RecordAdd("task") + s.RecordAdd("task") + s.RecordAdd("decision") + if s.AddsPerformed["task"] != 2 { + t.Errorf( + "task adds = %d, want 2", + s.AddsPerformed["task"], + ) + } + if s.AddsPerformed["decision"] != 1 { + t.Errorf( + "decision adds = %d, want 1", + s.AddsPerformed["decision"], + ) + } +} + +func TestQueuePendingUpdate(t *testing.T) { + s := NewMCPSession() + now := time.Now() + s.QueuePendingUpdate(PendingUpdate{ + Type: "task", + Content: "Build feature", + QueuedAt: now, + }) + if len(s.PendingFlush) != 1 { + t.Fatalf( + "PendingFlush length = %d, want 1", + len(s.PendingFlush), + ) + } + pu := s.PendingFlush[0] + if pu.Type != "task" { + t.Errorf( + "Type = %q, want %q", + pu.Type, "task", + ) + } + if pu.Content != "Build feature" { + t.Errorf( + "Content = %q, want %q", + pu.Content, "Build feature", + ) + } +} diff --git a/internal/err/mcp/mcp.go b/internal/err/mcp/mcp.go index 7c6f2fb68..5cf3b90f2 100644 --- a/internal/err/mcp/mcp.go +++ b/internal/err/mcp/mcp.go @@ -65,3 +65,19 @@ func UnknownEventType(eventType string) error { eventType, ) } + +// InputTooLong returns an error when input exceeds the allowed +// length. +// +// Parameters: +// - field: the field name that is too long +// - maxLen: the maximum allowed length +// +// Returns: +// - error: " exceeds maximum length of " +func InputTooLong(field string, maxLen int) error { + return fmt.Errorf( + desc.Text(text.DescKeyMCPErrInputTooLong), + field, maxLen, + ) +} diff --git a/internal/mcp/proto/schema_test.go b/internal/mcp/proto/schema_test.go new file mode 100644 index 000000000..dc294c366 --- /dev/null +++ b/internal/mcp/proto/schema_test.go @@ -0,0 +1,418 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package proto_test + +import ( + "encoding/json" + "testing" + + cfgSchema "github.com/ActiveMemory/ctx/internal/config/mcp/schema" + "github.com/ActiveMemory/ctx/internal/mcp/proto" +) + +func roundTrip( + t *testing.T, v interface{}, dst interface{}, +) { + t.Helper() + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("marshal: %v", err) + } + if err := json.Unmarshal(data, dst); err != nil { + t.Fatalf("unmarshal: %v", err) + } +} + +func TestRequestRoundTrip(t *testing.T) { + orig := proto.Request{ + JSONRPC: "2.0", + ID: json.RawMessage(`1`), + Method: "tools/call", + Params: json.RawMessage(`{"name":"ctx_status"}`), + } + var got proto.Request + roundTrip(t, orig, &got) + if got.JSONRPC != orig.JSONRPC { + t.Errorf("JSONRPC = %q, want %q", + got.JSONRPC, orig.JSONRPC) + } + if got.Method != orig.Method { + t.Errorf("Method = %q, want %q", + got.Method, orig.Method) + } + if string(got.ID) != string(orig.ID) { + t.Errorf("ID = %s, want %s", got.ID, orig.ID) + } +} + +func TestResponseSuccessRoundTrip(t *testing.T) { + orig := proto.Response{ + JSONRPC: "2.0", + ID: json.RawMessage(`1`), + Result: map[string]string{"key": "value"}, + } + var got proto.Response + roundTrip(t, orig, &got) + if got.JSONRPC != "2.0" { + t.Errorf("JSONRPC = %q, want %q", + got.JSONRPC, "2.0") + } + if got.Error != nil { + t.Errorf("unexpected error: %v", got.Error) + } +} + +func TestResponseErrorRoundTrip(t *testing.T) { + orig := proto.Response{ + JSONRPC: "2.0", + ID: json.RawMessage(`1`), + Error: &proto.RPCError{ + Code: cfgSchema.ErrCodeNotFound, + Message: "method not found", + }, + } + var got proto.Response + roundTrip(t, orig, &got) + if got.Error == nil { + t.Fatal("expected error in response") + } + if got.Error.Code != cfgSchema.ErrCodeNotFound { + t.Errorf("Code = %d, want %d", + got.Error.Code, cfgSchema.ErrCodeNotFound) + } + if got.Error.Message != "method not found" { + t.Errorf("Message = %q, want %q", + got.Error.Message, "method not found") + } +} + +func TestNotificationRoundTrip(t *testing.T) { + orig := proto.Notification{ + JSONRPC: "2.0", + Method: "notifications/initialized", + } + var got proto.Notification + roundTrip(t, orig, &got) + if got.Method != "notifications/initialized" { + t.Errorf("Method = %q, want %q", + got.Method, "notifications/initialized") + } +} + +func TestRPCErrorWithData(t *testing.T) { + orig := proto.RPCError{ + Code: cfgSchema.ErrCodeInvalidArg, + Message: "invalid", + Data: map[string]string{"field": "name"}, + } + var got proto.RPCError + roundTrip(t, orig, &got) + if got.Code != cfgSchema.ErrCodeInvalidArg { + t.Errorf("Code = %d, want %d", + got.Code, cfgSchema.ErrCodeInvalidArg) + } +} + +func TestInitializeParamsRoundTrip(t *testing.T) { + orig := proto.InitializeParams{ + ProtocolVersion: cfgSchema.ProtocolVersion, + ClientInfo: proto.AppInfo{ + Name: "test-client", + Version: "1.0.0", + }, + } + var got proto.InitializeParams + roundTrip(t, orig, &got) + if got.ProtocolVersion != cfgSchema.ProtocolVersion { + t.Errorf("ProtocolVersion = %q, want %q", + got.ProtocolVersion, cfgSchema.ProtocolVersion) + } + if got.ClientInfo.Name != "test-client" { + t.Errorf("ClientInfo.Name = %q, want %q", + got.ClientInfo.Name, "test-client") + } +} + +func TestInitializeResultRoundTrip(t *testing.T) { + orig := proto.InitializeResult{ + ProtocolVersion: cfgSchema.ProtocolVersion, + Capabilities: proto.ServerCaps{ + Resources: &proto.ResourcesCap{ + Subscribe: true, + ListChanged: true, + }, + Tools: &proto.ToolsCap{ListChanged: true}, + Prompts: &proto.PromptsCap{ListChanged: false}, + }, + ServerInfo: proto.AppInfo{ + Name: "ctx", + Version: "0.3.0", + }, + } + var got proto.InitializeResult + roundTrip(t, orig, &got) + if got.Capabilities.Resources == nil { + t.Fatal("expected Resources capability") + } + if !got.Capabilities.Resources.Subscribe { + t.Error("expected Subscribe = true") + } +} + +func TestResourceRoundTrip(t *testing.T) { + orig := proto.Resource{ + URI: "ctx://context/tasks", + Name: "tasks", + MimeType: "text/markdown", + } + var got proto.Resource + roundTrip(t, orig, &got) + if got.URI != orig.URI { + t.Errorf("URI = %q, want %q", + got.URI, orig.URI) + } +} + +func TestToolRoundTrip(t *testing.T) { + orig := proto.Tool{ + Name: "ctx_status", + InputSchema: proto.InputSchema{ + Type: "object", + Properties: map[string]proto.Property{ + "verbose": { + Type: "boolean", + Description: "Verbose", + }, + }, + Required: []string{"verbose"}, + }, + Annotations: &proto.ToolAnnotations{ + ReadOnlyHint: true, + }, + } + var got proto.Tool + roundTrip(t, orig, &got) + if got.Name != "ctx_status" { + t.Errorf("Name = %q, want %q", + got.Name, "ctx_status") + } + if got.Annotations == nil || + !got.Annotations.ReadOnlyHint { + t.Error("expected ReadOnlyHint = true") + } +} + +func TestCallToolParamsRoundTrip(t *testing.T) { + orig := proto.CallToolParams{ + Name: "ctx_add", + Arguments: map[string]interface{}{ + "type": "task", + "content": "Test", + }, + } + var got proto.CallToolParams + roundTrip(t, orig, &got) + if got.Name != "ctx_add" { + t.Errorf("Name = %q, want %q", + got.Name, "ctx_add") + } +} + +func TestCallToolResultRoundTrip(t *testing.T) { + orig := proto.CallToolResult{ + Content: []proto.ToolContent{ + {Type: "text", Text: "Done"}, + }, + } + var got proto.CallToolResult + roundTrip(t, orig, &got) + if len(got.Content) != 1 { + t.Fatalf("Content count = %d, want 1", + len(got.Content)) + } + if got.Content[0].Text != "Done" { + t.Errorf("Text = %q, want %q", + got.Content[0].Text, "Done") + } + if got.IsError { + t.Error("expected IsError = false") + } +} + +func TestCallToolResultErrorRoundTrip(t *testing.T) { + orig := proto.CallToolResult{ + Content: []proto.ToolContent{ + {Type: "text", Text: "failed"}, + }, + IsError: true, + } + var got proto.CallToolResult + roundTrip(t, orig, &got) + if !got.IsError { + t.Error("expected IsError = true") + } +} + +func TestPromptRoundTrip(t *testing.T) { + orig := proto.Prompt{ + Name: "ctx-session-start", + Arguments: []proto.PromptArgument{ + {Name: "content", Required: true}, + }, + } + var got proto.Prompt + roundTrip(t, orig, &got) + if got.Name != "ctx-session-start" { + t.Errorf("Name = %q, want %q", + got.Name, "ctx-session-start") + } + if len(got.Arguments) != 1 || + !got.Arguments[0].Required { + t.Error("expected 1 required argument") + } +} + +func TestGetPromptResultRoundTrip(t *testing.T) { + orig := proto.GetPromptResult{ + Description: "Test", + Messages: []proto.PromptMessage{ + { + Role: "user", + Content: proto.ToolContent{ + Type: "text", + Text: "Hi", + }, + }, + }, + } + var got proto.GetPromptResult + roundTrip(t, orig, &got) + if len(got.Messages) != 1 { + t.Fatalf("Messages count = %d, want 1", + len(got.Messages)) + } + if got.Messages[0].Role != "user" { + t.Errorf("Role = %q, want %q", + got.Messages[0].Role, "user") + } +} + +func TestSubscribeParamsRoundTrip(t *testing.T) { + orig := proto.SubscribeParams{ + URI: "ctx://context/tasks", + } + var got proto.SubscribeParams + roundTrip(t, orig, &got) + if got.URI != orig.URI { + t.Errorf("URI = %q, want %q", + got.URI, orig.URI) + } +} + +func TestUnsubscribeParamsRoundTrip(t *testing.T) { + orig := proto.UnsubscribeParams{ + URI: "ctx://context/decisions", + } + var got proto.UnsubscribeParams + roundTrip(t, orig, &got) + if got.URI != orig.URI { + t.Errorf("URI = %q, want %q", + got.URI, orig.URI) + } +} + +func TestResourceUpdatedParamsRoundTrip(t *testing.T) { + orig := proto.ResourceUpdatedParams{ + URI: "ctx://context/tasks", + } + var got proto.ResourceUpdatedParams + roundTrip(t, orig, &got) + if got.URI != orig.URI { + t.Errorf("URI = %q, want %q", + got.URI, orig.URI) + } +} + +func TestErrorCodeConstants(t *testing.T) { + if cfgSchema.ErrCodeParse != -32700 { + t.Errorf("ErrCodeParse = %d, want -32700", + cfgSchema.ErrCodeParse) + } + if cfgSchema.ErrCodeNotFound != -32601 { + t.Errorf("ErrCodeNotFound = %d, want -32601", + cfgSchema.ErrCodeNotFound) + } + if cfgSchema.ErrCodeInvalidArg != -32602 { + t.Errorf("ErrCodeInvalidArg = %d, want -32602", + cfgSchema.ErrCodeInvalidArg) + } + if cfgSchema.ErrCodeInternal != -32603 { + t.Errorf("ErrCodeInternal = %d, want -32603", + cfgSchema.ErrCodeInternal) + } +} + +func TestProtocolVersionValue(t *testing.T) { + if cfgSchema.ProtocolVersion != "2024-11-05" { + t.Errorf("ProtocolVersion = %q, want %q", + cfgSchema.ProtocolVersion, "2024-11-05") + } +} + +func TestRequestNilParams(t *testing.T) { + orig := proto.Request{ + JSONRPC: "2.0", + ID: json.RawMessage(`"abc"`), + Method: "ping", + } + data, err := json.Marshal(orig) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var got proto.Request + if err := json.Unmarshal(data, &got); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if got.Params != nil { + t.Errorf("expected nil Params, got %s", + got.Params) + } +} + +func TestResponseNilID(t *testing.T) { + orig := proto.Response{ + JSONRPC: "2.0", + Error: &proto.RPCError{ + Code: cfgSchema.ErrCodeParse, + Message: "parse error", + }, + } + var got proto.Response + roundTrip(t, orig, &got) + if got.ID != nil { + t.Errorf("expected nil ID, got %s", got.ID) + } +} + +func TestPropertyEnumRoundTrip(t *testing.T) { + orig := proto.Property{ + Type: "string", + Enum: []string{ + "task", "decision", "learning", + }, + } + var got proto.Property + roundTrip(t, orig, &got) + if len(got.Enum) != 3 { + t.Fatalf("Enum count = %d, want 3", + len(got.Enum)) + } + if got.Enum[0] != "task" { + t.Errorf("Enum[0] = %q, want %q", + got.Enum[0], "task") + } +} diff --git a/internal/mcp/server/def/prompt/prompt_test.go b/internal/mcp/server/def/prompt/prompt_test.go new file mode 100644 index 000000000..a3cebdee8 --- /dev/null +++ b/internal/mcp/server/def/prompt/prompt_test.go @@ -0,0 +1,150 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package prompt + +import ( + "testing" + + cfgPrompt "github.com/ActiveMemory/ctx/internal/config/mcp/prompt" +) + +func TestDefsCount(t *testing.T) { + if len(Defs) != 5 { + t.Errorf("prompt count = %d, want 5", len(Defs)) + } +} + +func TestDefsNoDuplicateNames(t *testing.T) { + seen := make(map[string]bool) + for _, d := range Defs { + if seen[d.Name] { + t.Errorf("duplicate prompt name: %s", d.Name) + } + seen[d.Name] = true + } +} + +func TestDefsAllNamed(t *testing.T) { + for i, d := range Defs { + if d.Name == "" { + t.Errorf("prompt[%d] has empty name", i) + } + } +} + +func TestDefsContainsAllConfigPrompts(t *testing.T) { + want := []string{ + cfgPrompt.SessionStart, + cfgPrompt.AddDecision, + cfgPrompt.AddLearning, + cfgPrompt.Reflect, + cfgPrompt.Checkpoint, + } + names := make(map[string]bool) + for _, d := range Defs { + names[d.Name] = true + } + for _, w := range want { + if !names[w] { + t.Errorf("missing prompt: %s", w) + } + } +} + +func TestDefsAddDecisionArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.AddDecision { + continue + } + if len(d.Arguments) != 4 { + t.Errorf( + "add-decision argument count = %d, want 4", + len(d.Arguments), + ) + } + for _, a := range d.Arguments { + if !a.Required { + t.Errorf( + "argument %q should be required", a.Name, + ) + } + } + return + } + t.Error("add-decision prompt not found") +} + +func TestDefsAddLearningArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.AddLearning { + continue + } + if len(d.Arguments) != 4 { + t.Errorf( + "add-learning argument count = %d, want 4", + len(d.Arguments), + ) + } + for _, a := range d.Arguments { + if !a.Required { + t.Errorf( + "argument %q should be required", a.Name, + ) + } + } + return + } + t.Error("add-learning prompt not found") +} + +func TestDefsSessionStartNoArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.SessionStart { + continue + } + if len(d.Arguments) != 0 { + t.Errorf( + "session-start should have 0 args, got %d", + len(d.Arguments), + ) + } + return + } + t.Error("session-start prompt not found") +} + +func TestDefsReflectNoArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.Reflect { + continue + } + if len(d.Arguments) != 0 { + t.Errorf( + "reflect should have 0 args, got %d", + len(d.Arguments), + ) + } + return + } + t.Error("reflect prompt not found") +} + +func TestDefsCheckpointNoArgs(t *testing.T) { + for _, d := range Defs { + if d.Name != cfgPrompt.Checkpoint { + continue + } + if len(d.Arguments) != 0 { + t.Errorf( + "checkpoint should have 0 args, got %d", + len(d.Arguments), + ) + } + return + } + t.Error("checkpoint prompt not found") +} diff --git a/internal/mcp/server/def/tool/tool_test.go b/internal/mcp/server/def/tool/tool_test.go new file mode 100644 index 000000000..14fce6a08 --- /dev/null +++ b/internal/mcp/server/def/tool/tool_test.go @@ -0,0 +1,140 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package tool + +import ( + "testing" + + cfgMcpTool "github.com/ActiveMemory/ctx/internal/config/mcp/tool" + "github.com/ActiveMemory/ctx/internal/mcp/proto" +) + +func TestDefsCount(t *testing.T) { + if len(Defs()) != 15 { + t.Errorf("tool count = %d, want 15", len(Defs())) + } +} + +func TestDefsNoDuplicateNames(t *testing.T) { + seen := make(map[string]bool) + for _, d := range Defs() { + if seen[d.Name] { + t.Errorf("duplicate tool name: %s", d.Name) + } + seen[d.Name] = true + } +} + +func TestDefsAllNamed(t *testing.T) { + for i, d := range Defs() { + if d.Name == "" { + t.Errorf("tool[%d] has empty name", i) + } + } +} + +// Note: Description fields are populated by desc.Text() at package +// init time. They are verified as non-empty in the server integration +// tests where lookup.Init() runs before this package is imported. + +func TestDefsAllHaveObjectSchema(t *testing.T) { + for _, d := range Defs() { + if d.InputSchema.Type != "object" { + t.Errorf( + "tool %q schema type = %q, want %q", + d.Name, d.InputSchema.Type, "object", + ) + } + } +} + +func TestDefsContainsAllConfigTools(t *testing.T) { + want := []string{ + cfgMcpTool.Status, + cfgMcpTool.Add, + cfgMcpTool.Complete, + cfgMcpTool.Drift, + cfgMcpTool.JournalSource, + cfgMcpTool.WatchUpdate, + cfgMcpTool.Compact, + cfgMcpTool.Next, + cfgMcpTool.CheckTaskCompletion, + cfgMcpTool.SessionEvent, + cfgMcpTool.Remind, + cfgMcpTool.SteeringGet, + cfgMcpTool.Search, + cfgMcpTool.SessionStart, + cfgMcpTool.SessionEnd, + } + names := make(map[string]bool) + for _, d := range Defs() { + names[d.Name] = true + } + for _, w := range want { + if !names[w] { + t.Errorf("missing tool: %s", w) + } + } +} + +func TestDefsAnnotations(t *testing.T) { + for _, d := range Defs() { + if d.Annotations == nil { + t.Errorf( + "tool %q has nil annotations", d.Name, + ) + } + } +} + +func TestDefsAddRequiredFields(t *testing.T) { + for _, d := range Defs() { + if d.Name != cfgMcpTool.Add { + continue + } + if len(d.InputSchema.Required) < 2 { + t.Errorf( + "add tool requires at least 2 fields, got %d", + len(d.InputSchema.Required), + ) + } + return + } + t.Error("add tool not found in Defs") +} + +func TestDefsMergeProps(t *testing.T) { + dst := map[string]proto.Property{ + "a": {Type: "string"}, + } + src := map[string]proto.Property{ + "b": {Type: "number"}, + } + result := MergeProps(dst, src) + if len(result) != 2 { + t.Errorf("merged length = %d, want 2", len(result)) + } + if result["b"].Type != "number" { + t.Errorf( + "result[b].Type = %q, want %q", + result["b"].Type, "number", + ) + } +} + +func TestDefsEntryAttrProps(t *testing.T) { + props := EntryAttrProps("test.key") + expected := []string{ + "context", "rationale", "consequence", + "lesson", "application", + } + for _, key := range expected { + if _, ok := props[key]; !ok { + t.Errorf("missing entry attr prop: %s", key) + } + } +} diff --git a/internal/mcp/server/extract/extract.go b/internal/mcp/server/extract/extract.go index a56b4de2b..270b4dad9 100644 --- a/internal/mcp/server/extract/extract.go +++ b/internal/mcp/server/extract/extract.go @@ -8,20 +8,25 @@ package extract import ( "github.com/ActiveMemory/ctx/internal/config/cli" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" "github.com/ActiveMemory/ctx/internal/config/mcp/field" "github.com/ActiveMemory/ctx/internal/entity" errMcp "github.com/ActiveMemory/ctx/internal/err/mcp" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // EntryArgs extracts required type/content from MCP args. // +// Validates that both fields are present and that content does not +// exceed MaxContentLen. +// // Parameters: // - args: MCP tool arguments // // Returns: // - string: extracted entry type // - string: extracted content string -// - error: non-nil if type or content is missing +// - error: non-nil if type or content is missing, or content too long func EntryArgs( args map[string]interface{}, ) (string, string, error) { @@ -32,6 +37,13 @@ func EntryArgs( return "", "", errMcp.TypeContentRequired() } + // MCP-SAN.1: Enforce input length limits. + if len(content) > cfg.MaxContentLen { + return "", "", errMcp.InputTooLong( + field.Content, cfg.MaxContentLen, + ) + } + return entryType, content, nil } @@ -76,3 +88,24 @@ func Opts(args map[string]interface{}) entity.EntryOpts { } return opts } + +// SanitizedOpts builds EntryOpts with content sanitization applied +// to all text fields. +// +// Parameters: +// - args: MCP tool arguments with optional entry fields +// +// Returns: +// - entity.EntryOpts: sanitized options struct +func SanitizedOpts( + args map[string]interface{}, +) entity.EntryOpts { + opts := Opts(args) + opts.Context = sanitize.Content(opts.Context) + opts.Rationale = sanitize.Content(opts.Rationale) + opts.Consequence = sanitize.Content(opts.Consequence) + opts.Lesson = sanitize.Content(opts.Lesson) + opts.Application = sanitize.Content(opts.Application) + opts.SessionID = sanitize.SessionID(opts.SessionID) + return opts +} diff --git a/internal/mcp/server/extract/extract_test.go b/internal/mcp/server/extract/extract_test.go new file mode 100644 index 000000000..1ac47af8d --- /dev/null +++ b/internal/mcp/server/extract/extract_test.go @@ -0,0 +1,113 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package extract + +import ( + "os" + "strings" + "testing" + + "github.com/ActiveMemory/ctx/internal/assets/read/lookup" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" +) + +func TestMain(m *testing.M) { + lookup.Init() + os.Exit(m.Run()) +} + +func TestEntryArgsValid(t *testing.T) { + args := map[string]interface{}{ + "type": "decision", + "content": "Use Go", + } + typ, content, err := EntryArgs(args) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if typ != "decision" { + t.Errorf("type = %q, want decision", typ) + } + if content != "Use Go" { + t.Errorf("content = %q, want Use Go", content) + } +} + +func TestEntryArgsMissingType(t *testing.T) { + args := map[string]interface{}{"content": "ok"} + _, _, err := EntryArgs(args) + if err == nil { + t.Fatal("expected error for missing type") + } +} + +func TestEntryArgsMissingContent(t *testing.T) { + args := map[string]interface{}{"type": "decision"} + _, _, err := EntryArgs(args) + if err == nil { + t.Fatal("expected error for missing content") + } +} + +func TestEntryArgsTooLong(t *testing.T) { + args := map[string]interface{}{ + "type": "decision", + "content": strings.Repeat("x", cfg.MaxContentLen+1), + } + _, _, err := EntryArgs(args) + if err == nil { + t.Fatal("expected error for content too long") + } +} + +func TestOptsAllFields(t *testing.T) { + args := map[string]interface{}{ + "priority": "high", + "context": "ctx", + "rationale": "because", + "consequence": "result", + "lesson": "learned", + "application": "apply", + } + opts := Opts(args) + if opts.Priority != "high" { + t.Errorf("priority = %q", opts.Priority) + } + if opts.Context != "ctx" { + t.Errorf("context = %q", opts.Context) + } + if opts.Rationale != "because" { + t.Errorf("rationale = %q", opts.Rationale) + } + if opts.Consequence != "result" { + t.Errorf("consequence = %q", opts.Consequence) + } + if opts.Lesson != "learned" { + t.Errorf("lesson = %q", opts.Lesson) + } + if opts.Application != "apply" { + t.Errorf("application = %q", opts.Application) + } +} + +func TestOptsEmpty(t *testing.T) { + opts := Opts(map[string]interface{}{}) + if opts.Priority != "" { + t.Error("expected empty priority") + } +} + +func TestSanitizedOpts(t *testing.T) { + args := map[string]interface{}{ + "context": "safe text", + "rationale": "good reason", + } + opts := SanitizedOpts(args) + if opts.Context != "safe text" { + t.Errorf("context = %q", opts.Context) + } +} diff --git a/internal/mcp/server/io/write_test.go b/internal/mcp/server/io/write_test.go new file mode 100644 index 000000000..ba6157de8 --- /dev/null +++ b/internal/mcp/server/io/write_test.go @@ -0,0 +1,48 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package io + +import ( + "bytes" + "os" + "testing" +) + +func TestWriteJSONSuccess(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + err := w.WriteJSON(map[string]int{"a": 1}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := buf.String(); got != "{\"a\":1}\n" { + t.Errorf("output = %q", got) + } +} + +func TestWriteJSONMarshalError(t *testing.T) { + var buf bytes.Buffer + w := NewWriter(&buf) + err := w.WriteJSON(make(chan int)) + if err == nil { + t.Fatal("expected marshal error") + } +} + +type errWriter struct{} + +func (errWriter) Write([]byte) (int, error) { + return 0, os.ErrClosed +} + +func TestWriteJSONWriteError(t *testing.T) { + w := NewWriter(errWriter{}) + err := w.WriteJSON("hello") + if err == nil { + t.Fatal("expected write error") + } +} diff --git a/internal/mcp/server/out/out_test.go b/internal/mcp/server/out/out_test.go new file mode 100644 index 000000000..4e9abff8c --- /dev/null +++ b/internal/mcp/server/out/out_test.go @@ -0,0 +1,114 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package out + +import ( + "encoding/json" + "errors" + "testing" + + cfgSchema "github.com/ActiveMemory/ctx/internal/config/mcp/schema" + "github.com/ActiveMemory/ctx/internal/mcp/proto" +) + +func TestOkResponse(t *testing.T) { + id, _ := json.Marshal(1) + resp := OkResponse(id, map[string]string{"k": "v"}) + if resp.JSONRPC != "2.0" { + t.Errorf("jsonrpc = %q", resp.JSONRPC) + } + if resp.Error != nil { + t.Error("unexpected error field") + } +} + +func TestErrResponse(t *testing.T) { + id, _ := json.Marshal(1) + resp := ErrResponse(id, cfgSchema.ErrCodeInternal, "boom") + if resp.Error == nil { + t.Fatal("expected error") + } + if resp.Error.Code != cfgSchema.ErrCodeInternal { + t.Errorf("code = %d", resp.Error.Code) + } + if resp.Error.Message != "boom" { + t.Errorf("msg = %q", resp.Error.Message) + } +} + +func TestToolOK(t *testing.T) { + id, _ := json.Marshal(1) + resp := ToolOK(id, "ok") + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if r.IsError { + t.Error("unexpected isError") + } + if r.Content[0].Text != "ok" { + t.Errorf("text = %q", r.Content[0].Text) + } +} + +func TestToolError(t *testing.T) { + id, _ := json.Marshal(1) + resp := ToolError(id, "fail") + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if !r.IsError { + t.Error("expected isError") + } +} + +func TestToolResultSuccess(t *testing.T) { + id, _ := json.Marshal(1) + resp := ToolResult(id, "done", nil) + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if r.IsError { + t.Error("unexpected isError") + } +} + +func TestToolResultError(t *testing.T) { + id, _ := json.Marshal(1) + resp := ToolResult(id, "", errors.New("bad")) + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if !r.IsError { + t.Error("expected isError") + } +} + +func TestCallSuccess(t *testing.T) { + id, _ := json.Marshal(1) + resp := Call(id, func() (string, error) { + return "ok", nil + }) + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if r.IsError { + t.Error("unexpected isError") + } +} + +func TestCallError(t *testing.T) { + id, _ := json.Marshal(1) + resp := Call(id, func() (string, error) { + return "", errors.New("oops") + }) + raw, _ := json.Marshal(resp.Result) + var r proto.CallToolResult + _ = json.Unmarshal(raw, &r) + if !r.IsError { + t.Error("expected isError") + } +} diff --git a/internal/mcp/server/parse/parse_test.go b/internal/mcp/server/parse/parse_test.go new file mode 100644 index 000000000..0532556a0 --- /dev/null +++ b/internal/mcp/server/parse/parse_test.go @@ -0,0 +1,56 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package parse + +import ( + "os" + "testing" + + "github.com/ActiveMemory/ctx/internal/assets/read/lookup" +) + +func TestMain(m *testing.M) { + lookup.Init() + os.Exit(m.Run()) +} + +func TestRequestValid(t *testing.T) { + data := []byte(`{"jsonrpc":"2.0","id":1,"method":"ping"}`) + req, errResp := Request(data) + switch { + case errResp != nil: + t.Fatal("unexpected error response") + case req == nil: + t.Fatal("expected non-nil request") + case req.Method != "ping": + t.Errorf("method = %q, want ping", req.Method) + } +} + +func TestRequestMalformed(t *testing.T) { + req, errResp := Request([]byte(`not-json`)) + if req != nil { + t.Fatal("expected nil request") + } + if errResp == nil || errResp.Error == nil { + t.Fatal("expected error response") + } + if errResp.Error.Code != -32700 { + t.Errorf("code = %d, want -32700", errResp.Error.Code) + } +} + +func TestRequestNotification(t *testing.T) { + data := []byte(`{"jsonrpc":"2.0","method":"notify"}`) + req, errResp := Request(data) + if req != nil { + t.Error("expected nil request for notification") + } + if errResp != nil { + t.Error("expected nil error for notification") + } +} diff --git a/internal/mcp/server/resource/dispatch.go b/internal/mcp/server/resource/dispatch.go index 7221876e7..48f0a5c13 100644 --- a/internal/mcp/server/resource/dispatch.go +++ b/internal/mcp/server/resource/dispatch.go @@ -12,11 +12,13 @@ import ( "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" cfgSchema "github.com/ActiveMemory/ctx/internal/config/mcp/schema" "github.com/ActiveMemory/ctx/internal/context/load" "github.com/ActiveMemory/ctx/internal/mcp/proto" "github.com/ActiveMemory/ctx/internal/mcp/server/catalog" "github.com/ActiveMemory/ctx/internal/mcp/server/out" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // DispatchList returns the pre-built resource list. @@ -78,7 +80,7 @@ func DispatchRead( return out.ErrResponse(req.ID, cfgSchema.ErrCodeInvalidArg, fmt.Sprintf( desc.Text(text.DescKeyMCPErrUnknownResource), - params.URI, + sanitize.Reflect(params.URI, cfg.MaxURILen), )) } diff --git a/internal/mcp/server/route/prompt/dispatch.go b/internal/mcp/server/route/prompt/dispatch.go index 1fc3e4878..3a31470ee 100644 --- a/internal/mcp/server/route/prompt/dispatch.go +++ b/internal/mcp/server/route/prompt/dispatch.go @@ -12,12 +12,14 @@ import ( "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" "github.com/ActiveMemory/ctx/internal/config/mcp/prompt" cfgSchema "github.com/ActiveMemory/ctx/internal/config/mcp/schema" "github.com/ActiveMemory/ctx/internal/entity" "github.com/ActiveMemory/ctx/internal/mcp/proto" defPrompt "github.com/ActiveMemory/ctx/internal/mcp/server/def/prompt" "github.com/ActiveMemory/ctx/internal/mcp/server/out" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // DispatchList returns all available prompts. @@ -72,7 +74,7 @@ func DispatchGet( req.ID, cfgSchema.ErrCodeNotFound, fmt.Sprintf( desc.Text(text.DescKeyMCPErrUnknownPrompt), - params.Name, + sanitize.Reflect(params.Name, cfg.MaxNameLen), ), ) } diff --git a/internal/mcp/server/route/tool/dispatch.go b/internal/mcp/server/route/tool/dispatch.go index f32fa58e5..54908d556 100644 --- a/internal/mcp/server/route/tool/dispatch.go +++ b/internal/mcp/server/route/tool/dispatch.go @@ -12,6 +12,7 @@ import ( "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" cfgSchema "github.com/ActiveMemory/ctx/internal/config/mcp/schema" "github.com/ActiveMemory/ctx/internal/config/mcp/tool" "github.com/ActiveMemory/ctx/internal/entity" @@ -19,6 +20,7 @@ import ( "github.com/ActiveMemory/ctx/internal/mcp/proto" defTool "github.com/ActiveMemory/ctx/internal/mcp/server/def/tool" "github.com/ActiveMemory/ctx/internal/mcp/server/out" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // DispatchList returns all available tools. @@ -111,7 +113,7 @@ func DispatchCall( req.ID, cfgSchema.ErrCodeNotFound, fmt.Sprintf( desc.Text(text.DescKeyMCPErrUnknownTool), - params.Name, + sanitize.Reflect(params.Name, cfg.MaxNameLen), ), ) } diff --git a/internal/mcp/server/route/tool/tool.go b/internal/mcp/server/route/tool/tool.go index 6fef5651a..b2057ea7f 100644 --- a/internal/mcp/server/route/tool/tool.go +++ b/internal/mcp/server/route/tool/tool.go @@ -14,6 +14,7 @@ import ( "github.com/ActiveMemory/ctx/internal/assets/read/desc" "github.com/ActiveMemory/ctx/internal/config/cli" "github.com/ActiveMemory/ctx/internal/config/embed/text" + "github.com/ActiveMemory/ctx/internal/config/entry" "github.com/ActiveMemory/ctx/internal/config/mcp/cfg" "github.com/ActiveMemory/ctx/internal/config/mcp/field" cfgTime "github.com/ActiveMemory/ctx/internal/config/time" @@ -22,6 +23,7 @@ import ( "github.com/ActiveMemory/ctx/internal/mcp/proto" "github.com/ActiveMemory/ctx/internal/mcp/server/extract" "github.com/ActiveMemory/ctx/internal/mcp/server/out" + "github.com/ActiveMemory/ctx/internal/sanitize" ) // add extracts MCP args and delegates to [handler.Add]. @@ -41,7 +43,20 @@ func add( if extractErr != nil { return out.ToolError(id, extractErr.Error()) } - t, addErr := handler.Add(d, entryType, content, extract.Opts(args)) + // MCP-SAN.2: Reject unknown entry types before writing. + if _, ok := entry.CtxFile(entryType); !ok { + return out.ToolError(id, fmt.Sprintf( + desc.Text(text.DescKeyMCPErrUnknownEntryType), + sanitize.Reflect(entryType, cfg.MaxNameLen), + )) + } + + // MCP-SAN.3: Sanitize content before writing to .context/. + content = sanitize.Content(content) + + t, addErr := handler.Add( + d, entryType, content, extract.SanitizedOpts(args), + ) return out.ToolResult(id, t, addErr) } @@ -64,6 +79,7 @@ func complete( id, desc.Text(text.DescKeyMCPErrQueryRequired), ) } + query = sanitize.Reflect(query, cfg.MaxQueryLen) t, completeErr := handler.Complete(d, query) return out.ToolResult(id, t, completeErr) } @@ -86,6 +102,11 @@ func journalSource( limit = int(v) } + // MCP-SAN.1: Cap source limit to a reasonable upper bound. + if limit > cfg.MaxSourceLimit { + limit = cfg.MaxSourceLimit + } + var since time.Time if sinceStr, _ := args[field.Since].(string); sinceStr != "" { var parseErr error @@ -122,8 +143,27 @@ func watchUpdate( if extractErr != nil { return out.ToolError(id, extractErr.Error()) } + // MCP-SAN.2: Reject unknown entry types (allow "complete" as + // special case handled by handler.WatchUpdate). + if entryType != entry.Complete { + if _, ok := entry.CtxFile(entryType); !ok { + return out.ToolError(id, fmt.Sprintf( + desc.Text( + text.DescKeyMCPErrUnknownEntryType, + ), + sanitize.Reflect( + entryType, cfg.MaxNameLen, + ), + )) + } + } + + // MCP-SAN.3: Sanitize content before writing to .context/. + content = sanitize.Content(content) + t, updateErr := handler.WatchUpdate( - d, entryType, content, extract.Opts(args), + d, entryType, content, + extract.SanitizedOpts(args), ) return out.ToolResult(id, t, updateErr) } @@ -189,6 +229,10 @@ func sessionEvent( ) } caller, _ := args[field.Caller].(string) + + // MCP-SAN.4: Sanitize caller before reflecting in response. + caller = sanitize.Reflect(caller, cfg.MaxCallerLen) + t, eventErr := handler.SessionEvent(d, eventType, caller) return out.ToolResult(id, t, eventErr) } diff --git a/internal/mcp/server/server_test.go b/internal/mcp/server/server_test.go index e1aceb694..b5d86844f 100644 --- a/internal/mcp/server/server_test.go +++ b/internal/mcp/server/server_test.go @@ -1024,6 +1024,36 @@ func TestPromptAddDecision(t *testing.T) { } } +func TestPromptAddLearning(t *testing.T) { + srv, _ := newTestServer(t) + resp := request(t, srv, "prompts/get", proto.GetPromptParams{ + Name: "ctx-learning-add", + Arguments: map[string]string{ + "content": "Always validate inputs", + "context": "MCP sanitization work", + "lesson": "Never trust external input", + "application": "Add validation at boundaries", + }, + }) + if resp.Error != nil { + t.Fatalf("unexpected error: %v", resp.Error.Message) + } + raw, _ := json.Marshal(resp.Result) + var result proto.GetPromptResult + if err := json.Unmarshal(raw, &result); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if len(result.Messages) == 0 { + t.Fatal("expected message in learning prompt") + } + text := result.Messages[0].Content.Text + if !strings.Contains(text, "Always validate inputs") { + t.Errorf( + "expected learning content in text, got: %s", text, + ) + } +} + func TestPromptReflect(t *testing.T) { srv, _ := newTestServer(t) resp := request(t, srv, "prompts/get", proto.GetPromptParams{ @@ -1379,3 +1409,80 @@ func TestToolSearchNoQuery(t *testing.T) { t.Error("expected error when query is missing") } } + +// --- Serve edge-case tests --- + +// errWriter is an io.Writer that always returns an error. +type errWriter struct{} + +func (errWriter) Write([]byte) (int, error) { + return 0, os.ErrClosed +} + +func TestServeEmptyLines(t *testing.T) { + srv, _ := newTestServer(t) + + // Feed an empty line followed by a valid ping. + idBytes, _ := json.Marshal(1) + req := proto.Request{ + JSONRPC: "2.0", + ID: idBytes, + Method: "ping", + } + line, _ := json.Marshal(req) + + // Empty line + valid request. + input := append([]byte("\n"), line...) + input = append(input, '\n') + + var out bytes.Buffer + srv.in = bytes.NewReader(input) + srv.out = mcpIO.NewWriter(&out) + if err := srv.Serve(); err != nil { + t.Fatalf("serve: %v", err) + } + + var resp proto.Response + if err := json.Unmarshal(out.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if resp.Error != nil { + t.Errorf("unexpected error: %v", resp.Error.Message) + } +} + +func TestServeParseErrorWriteFailure(t *testing.T) { + srv, _ := newTestServer(t) + + // Feed invalid JSON to trigger a parse error. + srv.in = bytes.NewReader([]byte("not-json\n")) + srv.out = mcpIO.NewWriter(errWriter{}) + + err := srv.Serve() + if err == nil { + t.Fatal("expected write error, got nil") + } +} + +func TestServeDispatchWriteFailure(t *testing.T) { + srv, _ := newTestServer(t) + + // Feed a valid request but use an errWriter for output. + idBytes, _ := json.Marshal(1) + req := proto.Request{ + JSONRPC: "2.0", + ID: idBytes, + Method: "ping", + } + line, _ := json.Marshal(req) + + srv.in = bytes.NewReader(append(line, '\n')) + srv.out = mcpIO.NewWriter(errWriter{}) + + // The marshal itself succeeds but the write fails, triggering + // the fallback error path which also fails, returning the error. + err := srv.Serve() + if err == nil { + t.Fatal("expected write error, got nil") + } +} diff --git a/internal/mcp/server/stat/stat_test.go b/internal/mcp/server/stat/stat_test.go new file mode 100644 index 000000000..be0b30ce3 --- /dev/null +++ b/internal/mcp/server/stat/stat_test.go @@ -0,0 +1,22 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package stat + +import "testing" + +func TestTotalAddsEmpty(t *testing.T) { + if got := TotalAdds(nil); got != 0 { + t.Errorf("TotalAdds(nil) = %d, want 0", got) + } +} + +func TestTotalAddsMultiple(t *testing.T) { + m := map[string]int{"decision": 2, "learning": 3, "convention": 1} + if got := TotalAdds(m); got != 6 { + t.Errorf("TotalAdds = %d, want 6", got) + } +} diff --git a/internal/sanitize/content.go b/internal/sanitize/content.go new file mode 100644 index 000000000..7441b06dc --- /dev/null +++ b/internal/sanitize/content.go @@ -0,0 +1,78 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package sanitize + +import ( + "strings" + "unicode" + + "github.com/ActiveMemory/ctx/internal/config/regex" + cfgSan "github.com/ActiveMemory/ctx/internal/config/sanitize" + "github.com/ActiveMemory/ctx/internal/config/token" +) + +// Content neutralizes Markdown structure characters in entry content +// that could corrupt .context/ file parsing. +// +// Escapes entry headers, task checkboxes, and constitution rule +// patterns so they render as literal text instead of structural +// elements. +// +// Parameters: +// - s: raw content string from MCP client +// +// Returns: +// - string: content safe for appending to .context/ Markdown files +func Content(s string) string { + // Escape entry headers: "## [2026-" → "\\## [2026-" + s = regex.SanEntryHeader.ReplaceAllStringFunc( + s, func(m string) string { + return cfgSan.EscapePrefix + m + }, + ) + + // Escape task checkboxes: "- [ ]" → "\\- [ ]" + s = regex.SanTaskCheckbox.ReplaceAllStringFunc( + s, func(m string) string { + return cfgSan.EscapePrefix + m + }, + ) + + // Escape constitution rules. + s = regex.SanConstitutionRule.ReplaceAllStringFunc( + s, func(m string) string { + return cfgSan.EscapePrefix + m + }, + ) + + // Strip null bytes. + s = strings.ReplaceAll(s, cfgSan.NullByte, "") + + return s +} + +// StripControl removes ASCII control characters (except tab and +// newline) from a string. +// +// Parameters: +// - s: input string potentially containing control characters +// +// Returns: +// - string: input with control characters removed +func StripControl(s string) string { + return strings.Map(func(r rune) rune { + if r == rune(token.Tab[0]) || + r == rune(token.NewlineLF[0]) || + r == rune(token.NewlineCRLF[0]) { + return r + } + if unicode.IsControl(r) { + return -1 + } + return r + }, s) +} diff --git a/internal/sanitize/doc.go b/internal/sanitize/doc.go index e4be5cf17..d9a471a71 100644 --- a/internal/sanitize/doc.go +++ b/internal/sanitize/doc.go @@ -8,6 +8,9 @@ // // Unlike validation (which rejects bad input), sanitization mutates // input to conform to constraints. [Filename] converts arbitrary -// strings into safe filename components. +// strings into safe filename components, [Content] neutralizes +// Markdown structure injections, [Reflect] truncates and strips +// control characters for error messages, and [SessionID] produces +// path-safe session identifiers. // Part of the internal subsystem. package sanitize diff --git a/internal/sanitize/path.go b/internal/sanitize/path.go new file mode 100644 index 000000000..0b0c98cdb --- /dev/null +++ b/internal/sanitize/path.go @@ -0,0 +1,51 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package sanitize + +import ( + "strings" + + "github.com/ActiveMemory/ctx/internal/config/regex" + cfgSan "github.com/ActiveMemory/ctx/internal/config/sanitize" +) + +// SessionID sanitizes a session identifier for safe use in file +// paths. +// +// Strips path separators, traversal sequences, and null bytes. +// Replaces remaining unsafe characters with hyphens and limits +// length to MaxSessionIDLen bytes. +// +// Parameters: +// - s: raw session ID from MCP client +// +// Returns: +// - string: path-safe session ID +func SessionID(s string) string { + // Strip null bytes. + s = strings.ReplaceAll(s, cfgSan.NullByte, "") + + // Collapse path traversal sequences. + s = strings.ReplaceAll(s, cfgSan.DotDot, "") + s = strings.ReplaceAll(s, cfgSan.ForwardSlash, "") + s = strings.ReplaceAll(s, cfgSan.Backslash, "") + + // Replace remaining unsafe chars. + s = regex.SanSessionIDUnsafe.ReplaceAllString( + s, cfgSan.HyphenReplace, + ) + + // Remove leading/trailing hyphens. + s = strings.Trim(s, cfgSan.HyphenReplace) + + // Limit length. + if len(s) > cfgSan.MaxSessionIDLen { + s = s[:cfgSan.MaxSessionIDLen] + } + + return s +} diff --git a/internal/sanitize/reflect.go b/internal/sanitize/reflect.go new file mode 100644 index 000000000..1b0271326 --- /dev/null +++ b/internal/sanitize/reflect.go @@ -0,0 +1,28 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package sanitize + +// Reflect truncates a string and strips control characters for safe +// inclusion in error or log messages. +// +// Use this for any client-supplied value that gets reflected back in +// JSON-RPC error messages (tool names, prompt names, URIs, caller +// identifiers). +// +// Parameters: +// - s: untrusted input string +// - maxLen: maximum allowed length (0 = no truncation) +// +// Returns: +// - string: truncated, control-character-free string +func Reflect(s string, maxLen int) string { + s = StripControl(s) + if maxLen > 0 && len(s) > maxLen { + s = s[:maxLen] + } + return s +} diff --git a/internal/sanitize/sanitize_test.go b/internal/sanitize/sanitize_test.go new file mode 100644 index 000000000..47863a626 --- /dev/null +++ b/internal/sanitize/sanitize_test.go @@ -0,0 +1,172 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package sanitize + +import ( + "strings" + "testing" +) + +func TestContentEscapesEntryHeaders(t *testing.T) { + input := "## [2026-03-15] Decision title" + got := Content(input) + want := `\## [2026-03-15] Decision title` + if got != want { + t.Errorf("Content(%q) = %q, want %q", input, got, want) + } +} + +func TestContentEscapesTaskCheckboxUnchecked(t *testing.T) { + got := Content("- [ ] New task") + want := `\- [ ] New task` + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestContentEscapesTaskCheckboxChecked(t *testing.T) { + got := Content("- [x] Done task") + want := `\- [x] Done task` + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestContentEscapesConstitutionRules(t *testing.T) { + input := "- [ ] **Never break the constitution" + got := Content(input) + if !strings.HasPrefix(got, `\- [ ] **Never`) { + t.Errorf("got %q, want constitution rule escaped", got) + } +} + +func TestContentStripsNullBytes(t *testing.T) { + got := Content("hello\x00world") + if got != "helloworld" { + t.Errorf("got %q, want %q", got, "helloworld") + } +} + +func TestContentPreservesNormalText(t *testing.T) { + input := "This is a normal architecture decision." + got := Content(input) + if got != input { + t.Errorf("got %q, want unchanged", got) + } +} + +func TestContentMultilineInjection(t *testing.T) { + input := "Legit\n## [2026-01-01] Injected\n- [ ] Fake" + got := Content(input) + if strings.Contains(got, "\n## [2026") { + t.Error("entry header injection not escaped") + } + if strings.Contains(got, "\n- [ ] Fake") { + t.Error("checkbox injection not escaped") + } +} + +func TestReflectTruncates(t *testing.T) { + got := Reflect(strings.Repeat("a", 500), 256) + if len(got) != 256 { + t.Errorf("len = %d, want 256", len(got)) + } +} + +func TestReflectStripsControlChars(t *testing.T) { + got := Reflect("tool\x07name\x1b[31m", 0) + if got != "toolname[31m" { + t.Errorf("got %q, want %q", got, "toolname[31m") + } +} + +func TestReflectPreservesNormal(t *testing.T) { + got := Reflect("ctx_status", 256) + if got != "ctx_status" { + t.Errorf("got %q, want unchanged", got) + } +} + +func TestReflectZeroMaxLen(t *testing.T) { + got := Reflect(strings.Repeat("x", 1000), 0) + if len(got) != 1000 { + t.Errorf("len = %d, want 1000 (no truncation)", len(got)) + } +} + +func TestTruncateShort(t *testing.T) { + if got := truncate("short", 100); got != "short" { + t.Errorf("got %q", got) + } +} + +func TestTruncateLong(t *testing.T) { + if got := truncate("long input", 4); got != "long" { + t.Errorf("got %q", got) + } +} + +func TestTruncateZero(t *testing.T) { + if got := truncate("any", 0); got != "any" { + t.Errorf("got %q", got) + } +} + +func TestStripControlPreservesWhitespace(t *testing.T) { + input := "a\nb\tc\r" + if got := StripControl(input); got != input { + t.Errorf("got %q, want unchanged", got) + } +} + +func TestStripControlRemovesBell(t *testing.T) { + if got := StripControl("hello\x07world"); got != "helloworld" { + t.Errorf("got %q", got) + } +} + +func TestSessionIDSafe(t *testing.T) { + input := "session-2026-03-15" + if got := SessionID(input); got != input { + t.Errorf("got %q, want unchanged", got) + } +} + +func TestSessionIDStripsTraversal(t *testing.T) { + got := SessionID("../../etc/passwd") + if strings.Contains(got, "..") || strings.Contains(got, "/") { + t.Errorf("got %q, contains traversal", got) + } +} + +func TestSessionIDStripsBackslashTraversal(t *testing.T) { + got := SessionID(`..\..\windows\system32`) + if strings.Contains(got, "..") || strings.Contains(got, `\`) { + t.Errorf("got %q, contains traversal", got) + } +} + +func TestSessionIDStripsNullBytes(t *testing.T) { + got := SessionID("session\x00evil") + if strings.Contains(got, "\x00") { + t.Errorf("got %q, contains null byte", got) + } +} + +func TestSessionIDLimitsLength(t *testing.T) { + got := SessionID(strings.Repeat("a", 300)) + if len(got) > 128 { + t.Errorf("len = %d, want <= 128", len(got)) + } +} + +func TestSessionIDReplacesUnsafe(t *testing.T) { + got := SessionID("session with spaces!@#$") + if strings.ContainsAny(got, " !@#$") { + t.Errorf("got %q, contains unsafe chars", got) + } +} diff --git a/internal/sanitize/truncate.go b/internal/sanitize/truncate.go new file mode 100644 index 000000000..cd2526a7a --- /dev/null +++ b/internal/sanitize/truncate.go @@ -0,0 +1,23 @@ +// / ctx: https://ctx.ist +// ,'`./ do you remember? +// `.,'\ +// \ Copyright 2026-present Context contributors. +// SPDX-License-Identifier: Apache-2.0 + +package sanitize + +// truncate limits a string to maxLen bytes. If truncated, no +// ellipsis is appended — the caller controls presentation. +// +// Parameters: +// - s: input string +// - maxLen: maximum byte length +// +// Returns: +// - string: input capped to maxLen bytes +func truncate(s string, maxLen int) string { + if maxLen > 0 && len(s) > maxLen { + return s[:maxLen] + } + return s +}