Skip to content

Commit 7736f33

Browse files
committed
Don't track usage when asked not to
Signed-off-by: David Gageot <david.gageot@docker.com>
1 parent 9c0158e commit 7736f33

6 files changed

Lines changed: 42 additions & 30 deletions

File tree

pkg/model/provider/anthropic/adapter.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ import (
1313

1414
// streamAdapter adapts the Anthropic stream to our interface
1515
type streamAdapter struct {
16-
stream *ssestream.Stream[anthropic.MessageStreamEventUnion]
17-
toolCall bool
18-
toolID string
16+
stream *ssestream.Stream[anthropic.MessageStreamEventUnion]
17+
trackUsage bool
18+
toolCall bool
19+
toolID string
1920
}
2021

21-
func newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion]) *streamAdapter {
22+
func newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter {
2223
return &streamAdapter{
23-
stream: stream,
24+
stream: stream,
25+
trackUsage: trackUsage,
2426
}
2527
}
2628

@@ -96,11 +98,13 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
9698
return response, fmt.Errorf("unknown delta type: %T", deltaVariant)
9799
}
98100
case anthropic.MessageDeltaEvent:
99-
response.Usage = &chat.Usage{
100-
InputTokens: int(eventVariant.Usage.InputTokens),
101-
OutputTokens: int(eventVariant.Usage.OutputTokens),
102-
CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens),
103-
CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens),
101+
if a.trackUsage {
102+
response.Usage = &chat.Usage{
103+
InputTokens: int(eventVariant.Usage.InputTokens),
104+
OutputTokens: int(eventVariant.Usage.OutputTokens),
105+
CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens),
106+
CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens),
107+
}
104108
}
105109
case anthropic.MessageStopEvent:
106110
if a.toolCall {

pkg/model/provider/anthropic/client.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,9 @@ func (c *Client) CreateChatCompletionStream(
256256
}
257257

258258
stream := client.Messages.NewStreaming(ctx, params)
259-
ad := newStreamAdapter(stream)
259+
trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
260+
ad := newStreamAdapter(stream, trackUsage)
261+
260262
slog.Debug("Anthropic chat completion stream created successfully", "model", c.ModelConfig.Model)
261263
return ad, nil
262264
}

pkg/model/provider/gemini/adapter.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
type StreamAdapter struct {
2020
ch chan result
2121
model string
22+
trackUsage bool
2223
mu sync.Mutex
2324
lastResponse *genai.GenerateContentResponse // Store last response for final message
2425
}
@@ -30,10 +31,11 @@ type result struct {
3031
}
3132

3233
// NewStreamAdapter constructs a StreamAdapter from Gemini's iterator
33-
func NewStreamAdapter(iter func(func(*genai.GenerateContentResponse, error) bool), model string) *StreamAdapter {
34+
func NewStreamAdapter(iter func(func(*genai.GenerateContentResponse, error) bool), model string, trackUsage bool) *StreamAdapter {
3435
adapter := &StreamAdapter{
35-
ch: make(chan result),
36-
model: model,
36+
ch: make(chan result),
37+
model: model,
38+
trackUsage: trackUsage,
3739
}
3840

3941
go func() {
@@ -173,7 +175,7 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
173175
resp.ID = res.resp.ResponseID
174176

175177
// Handle token usage if present
176-
if res.resp.UsageMetadata != nil {
178+
if res.resp.UsageMetadata != nil && g.trackUsage {
177179
resp.Usage = &chat.Usage{
178180
InputTokens: int(res.resp.UsageMetadata.PromptTokenCount),
179181
OutputTokens: int(res.resp.UsageMetadata.CandidatesTokenCount),

pkg/model/provider/gemini/adapter_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func TestStreamAdapter_FunctionCalls(t *testing.T) {
3434
fn(mockResp, nil)
3535
}
3636

37-
adapter := NewStreamAdapter(iter, "test-model")
37+
adapter := NewStreamAdapter(iter, "test-model", true)
3838

3939
// Read the response
4040
resp, err := adapter.Recv()

pkg/model/provider/gemini/client.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,8 @@ func (c *Client) CreateChatCompletionStream(
388388

389389
// Build a fresh client per request when using the gateway
390390
iter := client.Models.GenerateContentStream(ctx, c.ModelConfig.Model, contents, config)
391-
return NewStreamAdapter(iter, c.ModelConfig.Model), nil
391+
trackUsage := c.ModelConfig.TrackUsage == nil || *c.ModelConfig.TrackUsage
392+
return NewStreamAdapter(iter, c.ModelConfig.Model, trackUsage), nil
392393
}
393394

394395
// Rerank scores documents by relevance to the query using Gemini's structured

pkg/model/provider/oaistream/adapter.go

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,23 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
111111

112112
// Check if Usage field is present using the JSON metadata
113113
if openaiResponse.JSON.Usage.Valid() {
114-
usage := openaiResponse.Usage
115-
response.Usage = &chat.Usage{
116-
InputTokens: int(usage.PromptTokens),
117-
OutputTokens: int(usage.CompletionTokens),
118-
CachedInputTokens: 0,
119-
CachedOutputTokens: 0,
120-
ReasoningTokens: 0,
121-
}
122-
if usage.JSON.PromptTokensDetails.Valid() {
123-
response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens)
124-
}
125-
if usage.JSON.CompletionTokensDetails.Valid() {
126-
response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens)
114+
if a.trackUsage {
115+
usage := openaiResponse.Usage
116+
response.Usage = &chat.Usage{
117+
InputTokens: int(usage.PromptTokens),
118+
OutputTokens: int(usage.CompletionTokens),
119+
CachedInputTokens: 0,
120+
CachedOutputTokens: 0,
121+
ReasoningTokens: 0,
122+
}
123+
if usage.JSON.PromptTokensDetails.Valid() {
124+
response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens)
125+
}
126+
if usage.JSON.CompletionTokensDetails.Valid() {
127+
response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens)
128+
}
127129
}
130+
128131
// Use the tracked finish reason instead of hardcoding stop
129132
finishReason := a.lastFinishReason
130133
if finishReason == chat.FinishReasonNull || finishReason == "" {

0 commit comments

Comments (0)