Skip to content

Commit fc56a07

Browse files
authored
Merge pull request #1360 from mfenderov/feature/bedrock-prompt-caching
feat(bedrock): add prompt caching for supported models
2 parents 9abdafc + f4bb8b1 commit fc56a07

4 files changed

Lines changed: 393 additions & 51 deletions

File tree

docs/USAGE.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,19 @@ models:
576576
| `role_session_name` | string | Session name for assumed role | cagent-bedrock-session |
577577
| `external_id` | string | External ID for role assumption | (none) |
578578
| `endpoint_url` | string | Custom endpoint (VPC/testing) | (none) |
579+
| `interleaved_thinking` | bool | Enable reasoning during tool calls (requires thinking_budget) | false |
580+
| `disable_prompt_caching` | bool | Disable automatic prompt caching | false |
581+
582+
#### Prompt Caching (Bedrock)
583+
584+
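Prompt caching is enabled automatically for models that support it, reducing latency and cost. Support is detected from models.dev data: a model is treated as cache-capable when it lists a non-zero cache read or cache write price, and caching stays off if the model cannot be looked up. System prompts, tool definitions, and recent messages are cached with a 5-minute TTL.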
585+
586+
To disable:
587+
588+
```yaml
589+
provider_opts:
590+
  disable_prompt_caching: true
591+
```
579592

580593
**Supported models (via Converse API):**
581594

pkg/model/provider/bedrock/client.go

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,15 @@ import (
2020
"github.com/docker/cagent/pkg/environment"
2121
"github.com/docker/cagent/pkg/model/provider/base"
2222
"github.com/docker/cagent/pkg/model/provider/options"
23+
"github.com/docker/cagent/pkg/modelsdev"
2324
"github.com/docker/cagent/pkg/tools"
2425
)
2526

2627
// Client represents a Bedrock client wrapper implementing provider.Provider
2728
type Client struct {
2829
base.Config
29-
bedrockClient *bedrockruntime.Client
30+
bedrockClient *bedrockruntime.Client
31+
cachingSupported bool // Cached at init time for efficiency
3032
}
3133

3234
// bearerTokenTransport adds Authorization header with bearer token to requests
@@ -40,7 +42,6 @@ func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, err
4042
return t.base.RoundTrip(req)
4143
}
4244

43-
// NewClient creates a new Bedrock client from the provided configuration
4445
func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (*Client, error) {
4546
if cfg == nil {
4647
slog.Error("Bedrock client creation failed", "error", "model configuration is required")
@@ -109,19 +110,47 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
109110

110111
bedrockClient := bedrockruntime.NewFromConfig(awsCfg, clientOpts...)
111112

112-
slog.Debug("Bedrock client created successfully", "model", cfg.Model, "region", awsCfg.Region)
113+
// Detect prompt caching capability at init time for efficiency.
114+
// Uses models.dev cache pricing as proxy for capability detection.
115+
cachingSupported := detectCachingSupport(ctx, cfg.Model)
116+
117+
slog.Debug("Bedrock client created successfully",
118+
"model", cfg.Model,
119+
"region", awsCfg.Region,
120+
"caching_supported", cachingSupported)
113121

114122
return &Client{
115123
Config: base.Config{
116124
ModelConfig: *cfg,
117125
ModelOptions: globalOptions,
118126
Env: env,
119127
},
120-
bedrockClient: bedrockClient,
128+
bedrockClient: bedrockClient,
129+
cachingSupported: cachingSupported,
121130
}, nil
122131
}
123132

124-
// buildAWSConfig creates AWS config with proper credentials using the default credential chain
133+
// detectCachingSupport checks if a model supports prompt caching using models.dev data.
134+
// Models with a non-zero CacheRead or CacheWrite cost are treated as supporting prompt caching.
135+
// Returns false on lookup failure (safe default for unsupported models).
136+
func detectCachingSupport(ctx context.Context, model string) bool {
137+
store, err := modelsdev.NewStore()
138+
if err != nil {
139+
slog.Debug("Bedrock models store unavailable, prompt caching disabled", "error", err)
140+
return false
141+
}
142+
143+
modelID := "amazon-bedrock/" + model
144+
m, err := store.GetModel(ctx, modelID)
145+
if err != nil {
146+
slog.Debug("Bedrock prompt caching disabled: model not found in models.dev",
147+
"model_id", modelID, "error", err)
148+
return false
149+
}
150+
151+
return m.Cost != nil && (m.Cost.CacheRead > 0 || m.Cost.CacheWrite > 0)
152+
}
153+
125154
func buildAWSConfig(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider) (aws.Config, error) {
126155
var configOpts []func(*config.LoadOptions) error
127156

@@ -169,7 +198,6 @@ func buildAWSConfig(ctx context.Context, cfg *latest.ModelConfig, env environmen
169198
return awsCfg, nil
170199
}
171200

172-
// CreateChatCompletionStream creates a streaming chat completion request
173201
func (c *Client) CreateChatCompletionStream(
174202
ctx context.Context,
175203
messages []chat.Message,
@@ -198,21 +226,22 @@ func (c *Client) CreateChatCompletionStream(
198226
return newStreamAdapter(output.GetStream(), c.ModelConfig.Model, trackUsage), nil
199227
}
200228

201-
// buildConverseStreamInput creates the ConverseStream input parameters
202229
func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools []tools.Tool) *bedrockruntime.ConverseStreamInput {
203230
input := &bedrockruntime.ConverseStreamInput{
204231
ModelId: aws.String(c.ModelConfig.Model),
205232
}
206233

234+
enableCaching := c.promptCachingEnabled()
235+
207236
// Convert and set messages (excluding system)
208-
input.Messages, input.System = convertMessages(messages)
237+
input.Messages, input.System = convertMessages(messages, enableCaching)
209238

210239
// Set inference configuration
211240
input.InferenceConfig = c.buildInferenceConfig()
212241

213242
// Convert and set tools
214243
if len(requestTools) > 0 {
215-
input.ToolConfig = convertToolConfig(requestTools)
244+
input.ToolConfig = convertToolConfig(requestTools, enableCaching)
216245
}
217246

218247
// Set extended thinking configuration for Claude models
@@ -223,7 +252,6 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools
223252
return input
224253
}
225254

226-
// buildInferenceConfig creates the inference configuration
227255
func (c *Client) buildInferenceConfig() *types.InferenceConfiguration {
228256
cfg := &types.InferenceConfiguration{}
229257

@@ -247,8 +275,8 @@ func (c *Client) buildInferenceConfig() *types.InferenceConfiguration {
247275
return cfg
248276
}
249277

250-
// isThinkingEnabled checks if extended thinking will be enabled for this request.
251-
// This mirrors the validation logic in buildAdditionalModelRequestFields.
278+
// isThinkingEnabled mirrors the validation in buildAdditionalModelRequestFields
279+
// to determine if thinking params will affect inference config (temp/topP constraints).
252280
func (c *Client) isThinkingEnabled() bool {
253281
if c.ModelConfig.ThinkingBudget == nil || c.ModelConfig.ThinkingBudget.Tokens <= 0 {
254282
return false
@@ -269,13 +297,18 @@ func (c *Client) isThinkingEnabled() bool {
269297
return true
270298
}
271299

272-
// interleavedThinkingEnabled returns true when provider_opts.interleaved_thinking is set.
273300
func (c *Client) interleavedThinkingEnabled() bool {
274301
return getProviderOpt[bool](c.ModelConfig.ProviderOpts, "interleaved_thinking")
275302
}
276303

277-
// buildAdditionalModelRequestFields creates model-specific parameters.
278-
// Used for extended thinking (reasoning) configuration on Claude models.
304+
func (c *Client) promptCachingEnabled() bool {
305+
if getProviderOpt[bool](c.ModelConfig.ProviderOpts, "disable_prompt_caching") {
306+
return false
307+
}
308+
return c.cachingSupported
309+
}
310+
311+
// buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode.
279312
func (c *Client) buildAdditionalModelRequestFields() document.Interface {
280313
if c.ModelConfig.ThinkingBudget == nil || c.ModelConfig.ThinkingBudget.Tokens <= 0 {
281314
return nil
@@ -316,7 +349,6 @@ func (c *Client) buildAdditionalModelRequestFields() document.Interface {
316349
return document.NewLazyDocument(fields)
317350
}
318351

319-
// getProviderOpt extracts a typed value from provider_opts
320352
func getProviderOpt[T any](opts map[string]any, key string) T {
321353
var zero T
322354
if opts == nil {
@@ -328,6 +360,11 @@ func getProviderOpt[T any](opts map[string]any, key string) T {
328360
}
329361
typed, ok := v.(T)
330362
if !ok {
363+
slog.Warn("Bedrock provider_opts type mismatch",
364+
"key", key,
365+
"expected_type", fmt.Sprintf("%T", zero),
366+
"actual_type", fmt.Sprintf("%T", v),
367+
"value", v)
331368
return zero
332369
}
333370
return typed

0 commit comments

Comments
 (0)