Skip to content

Commit 2dcc3a2

Browse files
committed
handle system messages correctly with anthropic provider
Greatly reduces the number of model refusals and gets Anthropic models to behave much more in line with the actual system prompt instructions. Signed-off-by: Christopher Petito <chrisjpetito@gmail.com>
1 parent efb6338 commit 2dcc3a2

2 files changed

Lines changed: 124 additions & 4 deletions

File tree

pkg/model/provider/anthropic/client.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,11 @@ func (c *Client) CreateChatCompletionStream(
123123
Tools: convertTools(requestTools),
124124
}
125125

126+
// Populate proper Anthropic system prompt from input messages
127+
if sys := extractSystemBlocks(messages); len(sys) > 0 {
128+
params.System = sys
129+
}
130+
126131
if len(requestTools) > 0 {
127132
slog.Debug("Adding tools to Anthropic request", "tool_count", len(requestTools))
128133
}
@@ -164,6 +169,11 @@ func (c *Client) CreateChatCompletion(
164169
Messages: convertMessages(messages),
165170
}
166171

172+
// Populate proper Anthropic system prompt from input messages
173+
if sys := extractSystemBlocks(messages); len(sys) > 0 {
174+
params.System = sys
175+
}
176+
167177
// Build a fresh client per request when using the gateway
168178
client := c.client
169179
if c.useGateway {
@@ -185,10 +195,7 @@ func convertMessages(messages []chat.Message) []anthropic.MessageParam {
185195
for i := range messages {
186196
msg := &messages[i]
187197
if msg.Role == chat.MessageRoleSystem {
188-
// Convert system message to user message with system prefix
189-
if systemContent := strings.TrimSpace("System: " + msg.Content); systemContent != "System:" {
190-
anthropicMessages = append(anthropicMessages, anthropic.NewAssistantMessage(anthropic.NewTextBlock(systemContent)))
191-
}
198+
// System messages are handled via the top-level params.System
192199
continue
193200
}
194201
if msg.Role == chat.MessageRoleUser {
@@ -301,6 +308,30 @@ func convertMessages(messages []chat.Message) []anthropic.MessageParam {
301308
return anthropicMessages
302309
}
303310

311+
// extractSystemBlocks converts any system-role messages into Anthropic system text blocks
312+
// to be set on the top-level MessageNewParams.System field.
313+
func extractSystemBlocks(messages []chat.Message) []anthropic.TextBlockParam {
314+
var systemBlocks []anthropic.TextBlockParam
315+
for i := range messages {
316+
msg := &messages[i]
317+
if msg.Role != chat.MessageRoleSystem {
318+
continue
319+
}
320+
if len(msg.MultiContent) > 0 {
321+
for _, part := range msg.MultiContent {
322+
if part.Type == chat.MessagePartTypeText {
323+
if txt := strings.TrimSpace(part.Text); txt != "" {
324+
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{Text: txt})
325+
}
326+
}
327+
}
328+
} else if txt := strings.TrimSpace(msg.Content); txt != "" {
329+
systemBlocks = append(systemBlocks, anthropic.TextBlockParam{Text: txt})
330+
}
331+
}
332+
return systemBlocks
333+
}
334+
304335
func convertTools(tooles []tools.Tool) []anthropic.ToolUnionParam {
305336
toolParams := make([]anthropic.ToolParam, len(tooles))
306337

pkg/model/provider/anthropic/client_test.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package anthropic
22

33
import (
44
"encoding/json"
5+
"strings"
56
"testing"
67

78
"github.com/docker/cagent/pkg/chat"
@@ -118,3 +119,91 @@ func TestConvertMessages_AssistantToolCalls_NoText_IncludesToolUse(t *testing.T)
118119
t.Fatalf("expected content block type 'tool_use', got %v", typ)
119120
}
120121
}
122+
123+
func TestSystemMessages_AreExtractedAndNotInMessageList(t *testing.T) {
124+
msgs := []chat.Message{
125+
{Role: chat.MessageRoleSystem, Content: " system rules here "},
126+
{Role: chat.MessageRoleUser, Content: "hi"},
127+
}
128+
129+
// System blocks should be extracted
130+
sys := extractSystemBlocks(msgs)
131+
if len(sys) != 1 {
132+
t.Fatalf("expected 1 system block, got %d", len(sys))
133+
}
134+
if strings.TrimSpace(sys[0].Text) != "system rules here" {
135+
t.Fatalf("unexpected system text: %q", sys[0].Text)
136+
}
137+
138+
// System role messages must not appear in the anthropic messages list
139+
out := convertMessages(msgs)
140+
if len(out) != 1 {
141+
t.Fatalf("expected 1 non-system message, got %d", len(out))
142+
}
143+
}
144+
145+
func TestSystemMessages_MultipleExtractedAndExcludedFromMessageList(t *testing.T) {
146+
msgs := []chat.Message{
147+
{Role: chat.MessageRoleSystem, Content: " sys A "},
148+
{Role: chat.MessageRoleSystem, Content: "\n sys B \t"},
149+
{Role: chat.MessageRoleUser, Content: "hello"},
150+
}
151+
152+
sys := extractSystemBlocks(msgs)
153+
if len(sys) != 2 {
154+
t.Fatalf("expected 2 system blocks, got %d", len(sys))
155+
}
156+
if strings.TrimSpace(sys[0].Text) != "sys A" {
157+
t.Fatalf("unexpected first system text: %q", sys[0].Text)
158+
}
159+
if strings.TrimSpace(sys[1].Text) != "sys B" {
160+
t.Fatalf("unexpected second system text: %q", sys[1].Text)
161+
}
162+
163+
out := convertMessages(msgs)
164+
if len(out) != 1 {
165+
t.Fatalf("expected 1 non-system message, got %d", len(out))
166+
}
167+
}
168+
169+
func TestSystemMessages_InterspersedExtractedAndExcluded(t *testing.T) {
170+
msgs := []chat.Message{
171+
{Role: chat.MessageRoleSystem, Content: " S1 "},
172+
{Role: chat.MessageRoleUser, Content: "U1"},
173+
{Role: chat.MessageRoleAssistant, Content: "A1"},
174+
{Role: chat.MessageRoleSystem, Content: "S2"},
175+
{Role: chat.MessageRoleUser, Content: " U2 "},
176+
}
177+
178+
// All system messages should be extracted in order of appearance
179+
sys := extractSystemBlocks(msgs)
180+
if len(sys) != 2 {
181+
t.Fatalf("expected 2 system blocks, got %d", len(sys))
182+
}
183+
if strings.TrimSpace(sys[0].Text) != "S1" {
184+
t.Fatalf("unexpected first system text: %q", sys[0].Text)
185+
}
186+
if strings.TrimSpace(sys[1].Text) != "S2" {
187+
t.Fatalf("unexpected second system text: %q", sys[1].Text)
188+
}
189+
190+
// Converted messages must exclude system roles and preserve order of others
191+
out := convertMessages(msgs)
192+
if len(out) != 3 {
193+
t.Fatalf("expected 3 non-system messages, got %d", len(out))
194+
}
195+
// Check roles: user, assistant, user
196+
for i, expected := range []string{"user", "assistant", "user"} {
197+
b, err := json.Marshal(out[i])
198+
if err != nil {
199+
t.Fatalf("marshal error: %v", err)
200+
}
201+
var m map[string]any
202+
if err := json.Unmarshal(b, &m); err != nil {
203+
t.Fatalf("unmarshal error: %v", err)
204+
}
205+
if role, _ := m["role"].(string); role != expected {
206+
t.Fatalf("unexpected role at %d: got %q want %q", i, role, expected)
207+
}
208+
}
209+
}

0 commit comments

Comments
 (0)