Skip to content

Commit 1470488

Browse files
authored
Merge pull request #1349 from krissetto/fix-bedrock-usage-tracking
Fix token usage and cost tracking in amazon-bedrock provider
2 parents eb11a6f + c2240b1 commit 1470488

2 files changed

Lines changed: 125 additions & 9 deletions

File tree

pkg/model/provider/bedrock/adapter.go

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package bedrock
22

33
import (
4+
"fmt"
45
"io"
56
"log/slog"
67

@@ -20,6 +21,12 @@ type streamAdapter struct {
2021
// State for accumulating tool call data
2122
currentToolID string
2223
currentToolName string
24+
25+
// Buffered state for proper event ordering
26+
// Bedrock sends MessageStop before Metadata, but runtime expects usage before FinishReason
27+
pendingFinishReason chat.FinishReason
28+
pendingUsage *chat.Usage
29+
metadataReceived bool
2330
}
2431

2532
func newStreamAdapter(stream *bedrockruntime.ConverseStreamEventStream, model string, trackUsage bool) *streamAdapter {
@@ -32,12 +39,64 @@ func newStreamAdapter(stream *bedrockruntime.ConverseStreamEventStream, model st
3239

3340
// Recv gets the next completion chunk
3441
func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
42+
// If we have both finish reason and usage buffered, emit the final response
43+
// This handles both event orderings: MessageStop→Metadata and Metadata→MessageStop
44+
if a.pendingFinishReason != "" && a.metadataReceived {
45+
slog.Debug("Bedrock stream: emitting buffered final response",
46+
"finish_reason", a.pendingFinishReason,
47+
"has_usage", a.pendingUsage != nil)
48+
response := chat.MessageStreamResponse{
49+
Object: "chat.completion.chunk",
50+
Model: a.model,
51+
Choices: []chat.MessageStreamChoice{
52+
{
53+
Index: 0,
54+
FinishReason: a.pendingFinishReason,
55+
Delta: chat.MessageDelta{
56+
Role: string(chat.MessageRoleAssistant),
57+
},
58+
},
59+
},
60+
Usage: a.pendingUsage,
61+
}
62+
// Clear pending state
63+
a.pendingFinishReason = ""
64+
a.pendingUsage = nil
65+
a.metadataReceived = false
66+
return response, nil
67+
}
68+
3569
event, ok := <-a.stream.Events()
3670
if !ok {
3771
// Check for errors
3872
if err := a.stream.Err(); err != nil {
73+
slog.Debug("Bedrock stream: error on channel close", "error", err)
3974
return chat.MessageStreamResponse{}, err
4075
}
76+
// If we have a pending finish reason but never got metadata, emit it now
77+
if a.pendingFinishReason != "" {
78+
slog.Debug("Bedrock stream: channel closed, emitting pending finish reason without metadata",
79+
"finish_reason", a.pendingFinishReason,
80+
"has_usage", a.pendingUsage != nil)
81+
response := chat.MessageStreamResponse{
82+
Object: "chat.completion.chunk",
83+
Model: a.model,
84+
Choices: []chat.MessageStreamChoice{
85+
{
86+
Index: 0,
87+
FinishReason: a.pendingFinishReason,
88+
Delta: chat.MessageDelta{
89+
Role: string(chat.MessageRoleAssistant),
90+
},
91+
},
92+
},
93+
Usage: a.pendingUsage,
94+
}
95+
a.pendingFinishReason = ""
96+
a.pendingUsage = nil
97+
return response, nil
98+
}
99+
slog.Debug("Bedrock stream: channel closed, returning EOF")
41100
return chat.MessageStreamResponse{}, io.EOF
42101
}
43102

@@ -103,41 +162,59 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
103162
slog.Debug("Bedrock stream: content block stop", "index", ev.Value.ContentBlockIndex)
104163

105164
case *types.ConverseStreamOutputMemberMessageStop:
106-
// Message complete - determine finish reason
165+
// Buffer the finish reason - don't emit it yet, wait for metadata with usage
166+
// Bedrock sends MessageStop before Metadata, but runtime returns early on FinishReason
107167
stopReason := ev.Value.StopReason
108168
switch stopReason {
109169
case types.StopReasonToolUse:
110-
response.Choices[0].FinishReason = chat.FinishReasonToolCalls
170+
a.pendingFinishReason = chat.FinishReasonToolCalls
111171
case types.StopReasonEndTurn, types.StopReasonStopSequence:
112-
response.Choices[0].FinishReason = chat.FinishReasonStop
172+
a.pendingFinishReason = chat.FinishReasonStop
113173
case types.StopReasonMaxTokens:
114-
response.Choices[0].FinishReason = chat.FinishReasonLength
174+
a.pendingFinishReason = chat.FinishReasonLength
115175
default:
116-
response.Choices[0].FinishReason = chat.FinishReasonStop
176+
a.pendingFinishReason = chat.FinishReasonStop
117177
}
178+
slog.Debug("Bedrock stream: message stop (buffered)",
179+
"stop_reason", stopReason,
180+
"pending_finish_reason", a.pendingFinishReason,
181+
"metadata_already_received", a.metadataReceived)
118182

119183
case *types.ConverseStreamOutputMemberMetadata:
120-
// Metadata event with usage info - always capture if available
184+
// Metadata event with usage info - capture and mark received
185+
a.metadataReceived = true
186+
slog.Debug("Bedrock stream: received metadata event",
187+
"has_usage", ev.Value.Usage != nil,
188+
"finish_reason_already_received", a.pendingFinishReason != "")
189+
121190
if ev.Value.Usage != nil {
122191
usage := ev.Value.Usage
123-
slog.Debug("Bedrock stream: received usage metadata",
192+
slog.Debug("Bedrock stream: usage metadata details",
124193
"input_tokens", derefInt32(usage.InputTokens),
125194
"output_tokens", derefInt32(usage.OutputTokens),
126195
"cache_read_tokens", derefInt32(usage.CacheReadInputTokens),
127196
"cache_write_tokens", derefInt32(usage.CacheWriteInputTokens),
128197
"track_usage", a.trackUsage)
129198

130199
if a.trackUsage {
131-
response.Usage = &chat.Usage{
200+
a.pendingUsage = &chat.Usage{
132201
InputTokens: int64(derefInt32(usage.InputTokens)),
133202
OutputTokens: int64(derefInt32(usage.OutputTokens)),
134203
CachedInputTokens: int64(derefInt32(usage.CacheReadInputTokens)),
135204
CacheWriteTokens: int64(derefInt32(usage.CacheWriteInputTokens)),
136205
}
206+
slog.Debug("Bedrock stream: usage captured in pendingUsage",
207+
"input", a.pendingUsage.InputTokens,
208+
"output", a.pendingUsage.OutputTokens)
209+
} else {
210+
slog.Debug("Bedrock stream: usage NOT captured (trackUsage is false)")
137211
}
138212
} else {
139-
slog.Debug("Bedrock stream: metadata event has no usage data")
213+
slog.Debug("Bedrock stream: metadata event has nil Usage field")
140214
}
215+
216+
default:
217+
slog.Debug("Bedrock stream: unknown event type", "type", fmt.Sprintf("%T", event))
141218
}
142219

143220
return response, nil

pkg/modelsdev/store.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,25 @@ func (s *Store) GetModel(ctx context.Context, id string) (*Model, error) {
177177

178178
model, exists := provider.Models[modelID]
179179
if !exists {
180+
// For amazon-bedrock, try stripping region/inference profile prefixes
181+
// Bedrock uses prefixes like "global.", "us.", "eu.", "apac." etc. for
182+
// cross-region inference profiles, but models.dev stores models without
183+
// these prefixes. Try stripping the first segment if it doesn't match
184+
// a known model provider prefix (anthropic, meta, amazon, etc.)
185+
if providerID == "amazon-bedrock" {
186+
if idx := strings.Index(modelID, "."); idx != -1 {
187+
possibleRegionPrefix := modelID[:idx]
188+
// Only strip if the prefix is NOT a known model provider
189+
// (i.e., it's likely a region prefix like "global", "us", "eu")
190+
if !isBedrockModelProvider(possibleRegionPrefix) {
191+
normalizedModelID := modelID[idx+1:]
192+
model, exists = provider.Models[normalizedModelID]
193+
if exists {
194+
return &model, nil
195+
}
196+
}
197+
}
198+
}
180199
return nil, fmt.Errorf("model %q not found in provider %q", modelID, providerID)
181200
}
182201

@@ -316,3 +335,23 @@ func (s *Store) ResolveModelAlias(ctx context.Context, providerID, modelName str
316335

317336
return modelName
318337
}
338+
339+
// bedrockModelProviders contains known model provider prefixes used in Bedrock model IDs.
340+
// These are NOT region prefixes and should not be stripped when normalizing model IDs.
341+
var bedrockModelProviders = map[string]bool{
342+
"anthropic": true,
343+
"amazon": true,
344+
"meta": true,
345+
"cohere": true,
346+
"ai21": true,
347+
"mistral": true,
348+
"stability": true,
349+
"deepseek": true,
350+
"google": true,
351+
"minimax": true,
352+
}
353+
354+
// isBedrockModelProvider returns true if the prefix is a known Bedrock model provider.
355+
func isBedrockModelProvider(prefix string) bool {
356+
return bedrockModelProviders[prefix]
357+
}

0 commit comments

Comments (0)