@@ -20,13 +20,15 @@ import (
2020 "github.com/docker/cagent/pkg/environment"
2121 "github.com/docker/cagent/pkg/model/provider/base"
2222 "github.com/docker/cagent/pkg/model/provider/options"
23+ "github.com/docker/cagent/pkg/modelsdev"
2324 "github.com/docker/cagent/pkg/tools"
2425)
2526
2627// Client represents a Bedrock client wrapper implementing provider.Provider
2728type Client struct {
2829 base.Config
29- bedrockClient * bedrockruntime.Client
30+ bedrockClient * bedrockruntime.Client
31+ cachingSupported bool // Cached at init time for efficiency
3032}
3133
3234// bearerTokenTransport adds Authorization header with bearer token to requests
@@ -40,7 +42,6 @@ func (t *bearerTokenTransport) RoundTrip(req *http.Request) (*http.Response, err
4042 return t .base .RoundTrip (req )
4143}
4244
43- // NewClient creates a new Bedrock client from the provided configuration
4445func NewClient (ctx context.Context , cfg * latest.ModelConfig , env environment.Provider , opts ... options.Opt ) (* Client , error ) {
4546 if cfg == nil {
4647 slog .Error ("Bedrock client creation failed" , "error" , "model configuration is required" )
@@ -109,19 +110,47 @@ func NewClient(ctx context.Context, cfg *latest.ModelConfig, env environment.Pro
109110
110111 bedrockClient := bedrockruntime .NewFromConfig (awsCfg , clientOpts ... )
111112
112- slog .Debug ("Bedrock client created successfully" , "model" , cfg .Model , "region" , awsCfg .Region )
113+ // Detect prompt caching capability at init time for efficiency.
114+ // Uses models.dev cache pricing as proxy for capability detection.
115+ cachingSupported := detectCachingSupport (ctx , cfg .Model )
116+
117+ slog .Debug ("Bedrock client created successfully" ,
118+ "model" , cfg .Model ,
119+ "region" , awsCfg .Region ,
120+ "caching_supported" , cachingSupported )
113121
114122 return & Client {
115123 Config : base.Config {
116124 ModelConfig : * cfg ,
117125 ModelOptions : globalOptions ,
118126 Env : env ,
119127 },
120- bedrockClient : bedrockClient ,
128+ bedrockClient : bedrockClient ,
129+ cachingSupported : cachingSupported ,
121130 }, nil
122131}
123132
124- // buildAWSConfig creates AWS config with proper credentials using the default credential chain
133+ // detectCachingSupport checks if a model supports prompt caching using models.dev data.
134+ // Models with non-zero CacheRead/CacheWrite costs support prompt caching.
135+ // Returns false on lookup failure (safe default for unsupported models).
136+ func detectCachingSupport (ctx context.Context , model string ) bool {
137+ store , err := modelsdev .NewStore ()
138+ if err != nil {
139+ slog .Debug ("Bedrock models store unavailable, prompt caching disabled" , "error" , err )
140+ return false
141+ }
142+
143+ modelID := "amazon-bedrock/" + model
144+ m , err := store .GetModel (ctx , modelID )
145+ if err != nil {
146+ slog .Debug ("Bedrock prompt caching disabled: model not found in models.dev" ,
147+ "model_id" , modelID , "error" , err )
148+ return false
149+ }
150+
151+ return m .Cost != nil && (m .Cost .CacheRead > 0 || m .Cost .CacheWrite > 0 )
152+ }
153+
125154func buildAWSConfig (ctx context.Context , cfg * latest.ModelConfig , env environment.Provider ) (aws.Config , error ) {
126155 var configOpts []func (* config.LoadOptions ) error
127156
@@ -169,7 +198,6 @@ func buildAWSConfig(ctx context.Context, cfg *latest.ModelConfig, env environmen
169198 return awsCfg , nil
170199}
171200
172- // CreateChatCompletionStream creates a streaming chat completion request
173201func (c * Client ) CreateChatCompletionStream (
174202 ctx context.Context ,
175203 messages []chat.Message ,
@@ -198,21 +226,22 @@ func (c *Client) CreateChatCompletionStream(
198226 return newStreamAdapter (output .GetStream (), c .ModelConfig .Model , trackUsage ), nil
199227}
200228
201- // buildConverseStreamInput creates the ConverseStream input parameters
202229func (c * Client ) buildConverseStreamInput (messages []chat.Message , requestTools []tools.Tool ) * bedrockruntime.ConverseStreamInput {
203230 input := & bedrockruntime.ConverseStreamInput {
204231 ModelId : aws .String (c .ModelConfig .Model ),
205232 }
206233
234+ enableCaching := c .promptCachingEnabled ()
235+
207236 // Convert and set messages (excluding system)
208- input .Messages , input .System = convertMessages (messages )
237+ input .Messages , input .System = convertMessages (messages , enableCaching )
209238
210239 // Set inference configuration
211240 input .InferenceConfig = c .buildInferenceConfig ()
212241
213242 // Convert and set tools
214243 if len (requestTools ) > 0 {
215- input .ToolConfig = convertToolConfig (requestTools )
244+ input .ToolConfig = convertToolConfig (requestTools , enableCaching )
216245 }
217246
218247 // Set extended thinking configuration for Claude models
@@ -223,7 +252,6 @@ func (c *Client) buildConverseStreamInput(messages []chat.Message, requestTools
223252 return input
224253}
225254
226- // buildInferenceConfig creates the inference configuration
227255func (c * Client ) buildInferenceConfig () * types.InferenceConfiguration {
228256 cfg := & types.InferenceConfiguration {}
229257
@@ -247,8 +275,8 @@ func (c *Client) buildInferenceConfig() *types.InferenceConfiguration {
247275 return cfg
248276}
249277
250- // isThinkingEnabled checks if extended thinking will be enabled for this request.
251- // This mirrors the validation logic in buildAdditionalModelRequestFields .
278+ // isThinkingEnabled mirrors the validation in buildAdditionalModelRequestFields
279+ // to determine if thinking params will affect inference config (temp/topP constraints) .
252280func (c * Client ) isThinkingEnabled () bool {
253281 if c .ModelConfig .ThinkingBudget == nil || c .ModelConfig .ThinkingBudget .Tokens <= 0 {
254282 return false
@@ -269,13 +297,18 @@ func (c *Client) isThinkingEnabled() bool {
269297 return true
270298}
271299
272- // interleavedThinkingEnabled returns true when provider_opts.interleaved_thinking is set.
273300func (c * Client ) interleavedThinkingEnabled () bool {
274301 return getProviderOpt [bool ](c .ModelConfig .ProviderOpts , "interleaved_thinking" )
275302}
276303
277- // buildAdditionalModelRequestFields creates model-specific parameters.
278- // Used for extended thinking (reasoning) configuration on Claude models.
304+ func (c * Client ) promptCachingEnabled () bool {
305+ if getProviderOpt [bool ](c .ModelConfig .ProviderOpts , "disable_prompt_caching" ) {
306+ return false
307+ }
308+ return c .cachingSupported
309+ }
310+
311+ // buildAdditionalModelRequestFields configures Claude's extended thinking (reasoning) mode.
279312func (c * Client ) buildAdditionalModelRequestFields () document.Interface {
280313 if c .ModelConfig .ThinkingBudget == nil || c .ModelConfig .ThinkingBudget .Tokens <= 0 {
281314 return nil
@@ -316,7 +349,6 @@ func (c *Client) buildAdditionalModelRequestFields() document.Interface {
316349 return document .NewLazyDocument (fields )
317350}
318351
319- // getProviderOpt extracts a typed value from provider_opts
320352func getProviderOpt [T any ](opts map [string ]any , key string ) T {
321353 var zero T
322354 if opts == nil {
@@ -328,6 +360,11 @@ func getProviderOpt[T any](opts map[string]any, key string) T {
328360 }
329361 typed , ok := v .(T )
330362 if ! ok {
363+ slog .Warn ("Bedrock provider_opts type mismatch" ,
364+ "key" , key ,
365+ "expected_type" , fmt .Sprintf ("%T" , zero ),
366+ "actual_type" , fmt .Sprintf ("%T" , v ),
367+ "value" , v )
331368 return zero
332369 }
333370 return typed
0 commit comments