Skip to content

Commit 42bd35a

Browse files
authored
Merge pull request #971 from dgageot/fix-pricing
Fix cost calculation
2 parents 667692d + 02b53da commit 42bd35a

19 files changed

Lines changed: 78 additions & 85 deletions

File tree

pkg/api/types.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,8 @@ type SessionsResponse struct {
119119
Title string `json:"title"`
120120
CreatedAt string `json:"created_at"`
121121
NumMessages int `json:"num_messages"`
122-
InputTokens int `json:"input_tokens"`
123-
OutputTokens int `json:"output_tokens"`
122+
InputTokens int64 `json:"input_tokens"`
123+
OutputTokens int64 `json:"output_tokens"`
124124
WorkingDir string `json:"working_dir,omitempty"`
125125
}
126126

@@ -131,8 +131,8 @@ type SessionResponse struct {
131131
Messages []session.Message `json:"messages,omitempty"`
132132
CreatedAt time.Time `json:"created_at"`
133133
ToolsApproved bool `json:"tools_approved"`
134-
InputTokens int `json:"input_tokens"`
135-
OutputTokens int `json:"output_tokens"`
134+
InputTokens int64 `json:"input_tokens"`
135+
OutputTokens int64 `json:"output_tokens"`
136136
WorkingDir string `json:"working_dir,omitempty"`
137137
Pagination *PaginationMetadata `json:"pagination,omitempty"`
138138
}

pkg/chat/chat.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,11 @@ type MessageStreamResponse struct {
118118
}
119119

120120
type Usage struct {
121-
InputTokens int `json:"input_tokens"`
122-
OutputTokens int `json:"output_tokens"`
123-
CachedInputTokens int `json:"cached_input_tokens"`
124-
CachedOutputTokens int `json:"cached_output_tokens"`
125-
ReasoningTokens int `json:"reasoning_tokens,omitempty"`
121+
InputTokens int64 `json:"input_tokens"`
122+
OutputTokens int64 `json:"output_tokens"`
123+
CachedInputTokens int64 `json:"cached_input_tokens"`
124+
CacheWriteTokens int64 `json:"cached_output_tokens"`
125+
ReasoningTokens int64 `json:"reasoning_tokens,omitempty"`
126126
}
127127

128128
// MessageStream interface represents a stream of chat completions

pkg/model/provider/anthropic/adapter.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@ func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
151151
case anthropic.MessageDeltaEvent:
152152
if a.trackUsage {
153153
response.Usage = &chat.Usage{
154-
InputTokens: int(eventVariant.Usage.InputTokens),
155-
OutputTokens: int(eventVariant.Usage.OutputTokens),
156-
CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens),
157-
CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens),
154+
InputTokens: eventVariant.Usage.InputTokens,
155+
OutputTokens: eventVariant.Usage.OutputTokens,
156+
CachedInputTokens: eventVariant.Usage.CacheReadInputTokens,
157+
CacheWriteTokens: eventVariant.Usage.CacheCreationInputTokens,
158158
}
159159
}
160160
case anthropic.MessageStopEvent:

pkg/model/provider/anthropic/beta_adapter.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,10 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
112112
}
113113
case anthropic.BetaRawMessageDeltaEvent:
114114
response.Usage = &chat.Usage{
115-
InputTokens: int(eventVariant.Usage.InputTokens),
116-
OutputTokens: int(eventVariant.Usage.OutputTokens),
117-
CachedInputTokens: int(eventVariant.Usage.CacheReadInputTokens),
118-
CachedOutputTokens: int(eventVariant.Usage.CacheCreationInputTokens),
115+
InputTokens: eventVariant.Usage.InputTokens,
116+
OutputTokens: eventVariant.Usage.OutputTokens,
117+
CachedInputTokens: eventVariant.Usage.CacheReadInputTokens,
118+
CacheWriteTokens: eventVariant.Usage.CacheCreationInputTokens,
119119
}
120120
case anthropic.BetaRawMessageStopEvent:
121121
if a.toolCall {

pkg/model/provider/base/base.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ func (c *Config) BaseConfig() Config {
2626
// EmbeddingResult contains the embedding and usage information
2727
type EmbeddingResult struct {
2828
Embedding []float64
29-
InputTokens int
30-
TotalTokens int
29+
InputTokens int64
30+
TotalTokens int64
3131
Cost float64
3232
}
3333

3434
// BatchEmbeddingResult contains multiple embeddings and usage information
3535
type BatchEmbeddingResult struct {
3636
Embeddings [][]float64
37-
InputTokens int
38-
TotalTokens int
37+
InputTokens int64
38+
TotalTokens int64
3939
Cost float64
4040
}

pkg/model/provider/dmr/client.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -634,8 +634,8 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd
634634
copy(embedding, embedding32)
635635

636636
// Extract usage information
637-
inputTokens := int(response.Usage.PromptTokens)
638-
totalTokens := int(response.Usage.TotalTokens)
637+
inputTokens := response.Usage.PromptTokens
638+
totalTokens := response.Usage.TotalTokens
639639

640640
// DMR is local/free, so cost is 0
641641
cost := 0.0
@@ -657,10 +657,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd
657657
func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*base.BatchEmbeddingResult, error) {
658658
if len(texts) == 0 {
659659
return &base.BatchEmbeddingResult{
660-
Embeddings: [][]float64{},
661-
InputTokens: 0,
662-
TotalTokens: 0,
663-
Cost: 0,
660+
Embeddings: [][]float64{},
664661
}, nil
665662
}
666663

@@ -693,8 +690,8 @@ func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*bas
693690
}
694691

695692
// Extract usage information
696-
inputTokens := int(response.Usage.PromptTokens)
697-
totalTokens := int(response.Usage.TotalTokens)
693+
inputTokens := response.Usage.PromptTokens
694+
totalTokens := response.Usage.TotalTokens
698695

699696
// DMR is local/free, so cost is 0
700697
cost := 0.0

pkg/model/provider/gemini/adapter.go

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,10 @@ func (g *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
177177
// Handle token usage if present
178178
if res.resp.UsageMetadata != nil && g.trackUsage {
179179
resp.Usage = &chat.Usage{
180-
InputTokens: int(res.resp.UsageMetadata.PromptTokenCount),
181-
OutputTokens: int(res.resp.UsageMetadata.CandidatesTokenCount),
182-
CachedInputTokens: int(res.resp.UsageMetadata.CachedContentTokenCount),
183-
CachedOutputTokens: 0, // Gemini doesn't provide cached output tokens
184-
ReasoningTokens: int(res.resp.UsageMetadata.ThoughtsTokenCount),
180+
InputTokens: int64(res.resp.UsageMetadata.PromptTokenCount),
181+
OutputTokens: int64(res.resp.UsageMetadata.CandidatesTokenCount),
182+
CachedInputTokens: int64(res.resp.UsageMetadata.CachedContentTokenCount),
183+
ReasoningTokens: int64(res.resp.UsageMetadata.ThoughtsTokenCount),
185184
}
186185
}
187186

pkg/model/provider/oaistream/adapter.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,17 +114,15 @@ func (a *StreamAdapter) Recv() (chat.MessageStreamResponse, error) {
114114
if a.trackUsage {
115115
usage := openaiResponse.Usage
116116
response.Usage = &chat.Usage{
117-
InputTokens: int(usage.PromptTokens),
118-
OutputTokens: int(usage.CompletionTokens),
119-
CachedInputTokens: 0,
120-
CachedOutputTokens: 0,
121-
ReasoningTokens: 0,
117+
InputTokens: usage.PromptTokens,
118+
OutputTokens: usage.CompletionTokens,
122119
}
123120
if usage.JSON.PromptTokensDetails.Valid() {
124-
response.Usage.CachedInputTokens = int(usage.PromptTokensDetails.CachedTokens)
121+
response.Usage.CachedInputTokens = usage.PromptTokensDetails.CachedTokens
122+
response.Usage.InputTokens -= usage.PromptTokensDetails.CachedTokens
125123
}
126124
if usage.JSON.CompletionTokensDetails.Valid() {
127-
response.Usage.ReasoningTokens = int(usage.CompletionTokensDetails.ReasoningTokens)
125+
response.Usage.ReasoningTokens = usage.CompletionTokensDetails.ReasoningTokens
128126
}
129127
}
130128

pkg/model/provider/openai/client.go

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -657,10 +657,7 @@ func (c *Client) CreateEmbedding(ctx context.Context, text string) (*base.Embedd
657657
func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*base.BatchEmbeddingResult, error) {
658658
if len(texts) == 0 {
659659
return &base.BatchEmbeddingResult{
660-
Embeddings: [][]float64{},
661-
InputTokens: 0,
662-
TotalTokens: 0,
663-
Cost: 0,
660+
Embeddings: [][]float64{},
664661
}, nil
665662
}
666663

@@ -704,8 +701,8 @@ func (c *Client) CreateBatchEmbedding(ctx context.Context, texts []string) (*bas
704701
}
705702

706703
// Extract usage information
707-
inputTokens := int(response.Usage.PromptTokens)
708-
totalTokens := int(response.Usage.TotalTokens)
704+
inputTokens := response.Usage.PromptTokens
705+
totalTokens := response.Usage.TotalTokens
709706

710707
// Cost calculation is handled at the strategy level using models.dev pricing
711708
// Provider just returns token counts

pkg/model/provider/openai/response_stream.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,9 @@ func (a *ResponseStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
210210
u := event.Response.Usage
211211
if u.TotalTokens > 0 {
212212
response.Usage = &chat.Usage{
213-
InputTokens: int(u.InputTokens),
214-
OutputTokens: int(u.OutputTokens),
213+
InputTokens: u.InputTokens - u.InputTokensDetails.CachedTokens,
214+
OutputTokens: u.OutputTokens,
215+
CachedInputTokens: u.InputTokensDetails.CachedTokens,
215216
}
216217
}
217218
// Check if there were any tool calls in the output

0 commit comments

Comments (0)