diff --git a/usage-service/internal/httpapi/server.go b/usage-service/internal/httpapi/server.go index 4634c0d89..5c345123d 100644 --- a/usage-service/internal/httpapi/server.go +++ b/usage-service/internal/httpapi/server.go @@ -883,40 +883,45 @@ func (s *Server) handleUsage(w http.ResponseWriter, r *http.Request) { if !s.authorizeIfConfigured(w, r) { return } + path := strings.TrimRight(r.URL.Path, "/") + if path == "" { + path = "/" + } switch r.Method { case http.MethodGet: - if strings.HasSuffix(r.URL.Path, "/export") { + switch path { + case "/v0/management/usage/export": s.handleUsageExport(w, r) return - } - if strings.HasSuffix(r.URL.Path, "/summary") { + case "/v0/management/usage/summary": s.handleUsageSummary(w, r) return - } - if strings.HasSuffix(r.URL.Path, "/accounts") { + case "/v0/management/usage/accounts": s.handleUsageBreakdownPage(w, r, store.UsageBreakdownAccounts) return - } - if strings.HasSuffix(r.URL.Path, "/api-keys") { + case "/v0/management/usage/api-keys": s.handleUsageBreakdownPage(w, r, store.UsageBreakdownAPIKeys) return - } - if strings.HasSuffix(r.URL.Path, "/realtime") { + case "/v0/management/usage/realtime": s.handleUsageBreakdownPage(w, r, store.UsageBreakdownRealtime) return - } - if strings.HasSuffix(r.URL.Path, "/models") { + case "/v0/management/usage/models": s.handleUsageBreakdownPage(w, r, store.UsageBreakdownModels) return - } - events, err := s.store.RecentEvents(r.Context(), s.cfg.QueryLimit) - if err != nil { - writeError(w, http.StatusInternalServerError, err) + case "/v0/management/usage": + events, err := s.store.RecentEvents(r.Context(), s.cfg.QueryLimit) + if err != nil { + writeError(w, http.StatusInternalServerError, err) + return + } + writeJSON(w, http.StatusOK, usage.BuildPayload(events)) + return + default: + http.NotFound(w, r) return } - writeJSON(w, http.StatusOK, usage.BuildPayload(events)) case http.MethodPost: - if strings.HasSuffix(r.URL.Path, "/import") { + if path == "/v0/management/usage/import" { s.handleUsageImport(w, r) return } diff --git a/usage-service/internal/httpapi/server_test.go b/usage-service/internal/httpapi/server_test.go index d8d65fa1c..a1e609e65 100644 --- a/usage-service/internal/httpapi/server_test.go +++ b/usage-service/internal/httpapi/server_test.go @@ -345,6 +345,33 @@ func TestUsageBreakdownPageEndpointsReturnPagination(t *testing.T) { } } +func TestUsageEndpointsUseExactPathsWithTrailingSlashNormalization(t *testing.T) { + handler := newTestHandler(t, "http://example.test", true) + + req := httptest.NewRequest(http.MethodGet, "/v0/management/usage/summary/", nil) + req.Header.Set("Authorization", "Bearer management-key") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("summary trailing slash status = %d, body = %s", rr.Code, rr.Body.String()) + } + + for _, path := range []string{ + "/v0/management/usage/not-summary/summary", + "/v0/management/usage/accounts-extra", + } { + t.Run(path, func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, path, nil) + req.Header.Set("Authorization", "Bearer management-key") + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + if rr.Code != http.StatusNotFound { + t.Fatalf("status = %d, body = %s", rr.Code, rr.Body.String()) + } + }) + } +} + func TestUsageBreakdownPageRejectsUnsafePageFilters(t *testing.T) { handler := newTestHandler(t, "http://example.test", true) for _, path := range []string{ diff --git a/usage-service/internal/store/store.go b/usage-service/internal/store/store.go index 65e1ba7ec..995d6b432 100644 --- a/usage-service/internal/store/store.go +++ b/usage-service/internal/store/store.go @@ -3,6 +3,7 @@ package store import ( "context" "database/sql" + "database/sql/driver" "encoding/json" "errors" "fmt" @@ -14,11 +15,34 @@ import ( "strings" "time" - _ "modernc.org/sqlite" + sqlite "modernc.org/sqlite" "github.com/seakee/cpa-manager/usage-service/internal/usage" ) +func init() { + sqlite.MustRegisterDeterministicScalarFunction( + "usage_last_path_segment", + 1, + func(_ *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { + if len(args) == 0 || args[0] == nil { + return "", nil + } + return usageLastPathSegment(fmt.Sprint(args[0])), nil + }, + ) + sqlite.MustRegisterDeterministicScalarFunction( + "usage_strip_model_date_suffix", + 1, + func(_ *sqlite.FunctionContext, args []driver.Value) (driver.Value, error) { + if len(args) == 0 || args[0] == nil { + return "", nil + } + return usageStripModelDateSuffix(fmt.Sprint(args[0])), nil + }, + ) +} + type Setup struct { CPAUpstreamURL string `json:"cpaBaseUrl"` ManagementKey string `json:"managementKey,omitempty"` @@ -1399,35 +1423,360 @@ func (s *Store) usageAPIKeyFacet(ctx context.Context, whereClause string, args [ func (s *Store) UsageBreakdownPage(ctx context.Context, kind UsageBreakdownKind, filter UsageSummaryFilter, pageFilter UsagePageFilter) (UsagePage, error) { page, pageSize := normalizeUsagePageFilter(pageFilter) switch kind { + case UsageBreakdownAccounts, UsageBreakdownAPIKeys: + return s.usageGroupedBreakdownPage(ctx, kind, filter, pageFilter, page, pageSize) case UsageBreakdownRealtime: return s.usageRealtimePage(ctx, filter, page, pageSize) case UsageBreakdownModels: return s.usageModelPage(ctx, filter, page, pageSize) } + return UsagePage{}, fmt.Errorf("unknown usage breakdown kind %q", kind) +} - summary, err := s.usageSummary(ctx, filter, true) +func (s *Store) usageGroupedBreakdownPage(ctx context.Context, kind UsageBreakdownKind, filter UsageSummaryFilter, pageFilter UsagePageFilter, page int, pageSize int) (UsagePage, error) { + whereClause, args := filter.whereClause() + groupExpr := usageBreakdownGroupKeySQL(kind, "filtered.") + var totalItems int64 + if err := s.db.QueryRowContext(ctx, `with filtered as ( + select * from usage_events`+whereClause+` + ) + select count(*) from ( + select 1 from filtered group by `+groupExpr+` + )`, args...).Scan(&totalItems); err != nil { + return UsagePage{}, err + } + + groupKeys, err := s.usageBreakdownPageGroupKeys(ctx, kind, whereClause, args, pageFilter, page, pageSize) + if err != nil { + return UsagePage{}, err + } + if len(groupKeys) == 0 { + return UsagePage{ + Page: page, + PageSize: pageSize, + TotalItems: totalItems, + Usage: usage.Payload{APIs: map[string]*usage.APIAggregate{}}, + Items: []UsageBreakdownPageItem{}, + }, nil + } + + details, err := s.usageBreakdownDetailsForGroups(ctx, kind, whereClause, args, groupKeys) if err != nil { return UsagePage{}, err } + groups := buildBreakdownGroupsFromDetails(kind, groupKeys, details) + return UsagePage{ + Page: page, + PageSize: pageSize, + TotalItems: totalItems, + Usage: payloadSummaryFromBreakdownDetails(details), + Items: buildBreakdownPageItems(kind, groups), + }, nil +} - details := flattenUsagePayload(summary) - var prices map[string]ModelPrice - var priceIndex *usageModelPriceIndex +func (s *Store) usageBreakdownPageGroupKeys(ctx context.Context, kind UsageBreakdownKind, whereClause string, args []any, pageFilter UsagePageFilter, page int, pageSize int) ([]string, error) { + groupExpr := usageBreakdownGroupKeySQL(kind, "priced.") + pricedCTE := `, priced as ( + select filtered.*, 0 as event_cost from filtered + )` if normalizeUsageSortKey(pageFilter.SortKey) == "totalCost" { - prices, err = s.LoadModelPrices(ctx) - if err != nil { - return UsagePage{}, err + pricedCTE = usageBreakdownCostCTEs() + } + query := `with filtered as ( + select * from usage_events` + whereClause + ` + ) + ` + pricedCTE + ` + select group_key from ( + select + ` + groupExpr + ` as group_key, + count(*) as total_calls, + coalesce(sum(case when priced.failed = 0 then 1 else 0 end), 0) as success_calls, + coalesce(sum(case when priced.failed != 0 then 1 else 0 end), 0) as failure_calls, + coalesce(sum(priced.input_tokens), 0) as input_tokens, + coalesce(sum(priced.output_tokens), 0) as output_tokens, + coalesce(sum(max(priced.cached_tokens, priced.cache_tokens)), 0) as cached_tokens, + coalesce(sum(priced.total_tokens), 0) as total_tokens, + max(priced.timestamp_ms) as latest_timestamp, + coalesce(sum(priced.event_cost), 0) as total_cost + from priced + group by ` + groupExpr + ` + ) + order by ` + usageBreakdownOrderBySQL(pageFilter) + ` + limit ? offset ?` + rows, err := s.db.QueryContext(ctx, query, append(args, pageSize, (page-1)*pageSize)...) + if err != nil { + return nil, err + } + defer rows.Close() + + keys := []string{} + for rows.Next() { + var key string + if err := rows.Scan(&key); err != nil { + return nil, err + } + keys = append(keys, key) + } + if err := rows.Err(); err != nil { + return nil, err + } + return keys, nil +} + +func (s *Store) usageBreakdownDetailsForGroups(ctx context.Context, kind UsageBreakdownKind, whereClause string, args []any, groupKeys []string) ([]usageBreakdownDetail, error) { + groupExpr := usageBreakdownGroupKeySQL(kind, "filtered.") + placeholders := strings.TrimRight(strings.Repeat("?,", len(groupKeys)), ",") + query := `with filtered as ( + select * from usage_events` + whereClause + ` + ) + select + coalesce(nullif(endpoint, ''), '-'), + coalesce(nullif(model, ''), '-'), + coalesce(source, ''), + coalesce(auth_index, ''), + coalesce(api_key_hash, ''), + coalesce(account_snapshot, ''), + coalesce(auth_label_snapshot, ''), + coalesce(auth_file_snapshot, ''), + coalesce(auth_provider_snapshot, ''), + coalesce(auth_project_id_snapshot, ''), + coalesce(resolved_model, ''), + failed, + max(timestamp_ms), + coalesce(max(auth_snapshot_at_ms), 0), + count(*), + coalesce(sum(case when failed = 0 then 1 else 0 end), 0), + coalesce(sum(case when failed != 0 then 1 else 0 end), 0), + coalesce(sum(input_tokens), 0), + coalesce(sum(output_tokens), 0), + coalesce(sum(reasoning_tokens), 0), + coalesce(sum(cached_tokens), 0), + coalesce(sum(cache_tokens), 0), + coalesce(sum(total_tokens), 0), + coalesce(sum(latency_ms), 0), + count(latency_ms) + from filtered + where ` + groupExpr + ` in (` + placeholders + `) + group by ` + groupExpr + `, coalesce(nullif(endpoint, ''), '-'), coalesce(nullif(model, ''), '-'), + coalesce(source, ''), coalesce(auth_index, ''), coalesce(api_key_hash, ''), + coalesce(account_snapshot, ''), coalesce(auth_label_snapshot, ''), + coalesce(auth_file_snapshot, ''), coalesce(auth_provider_snapshot, ''), + coalesce(auth_project_id_snapshot, ''), coalesce(resolved_model, ''), failed + order by ` + groupExpr + `, max(timestamp_ms) desc` + queryArgs := append([]any{}, args...) + for _, key := range groupKeys { + queryArgs = append(queryArgs, key) + } + rows, err := s.db.QueryContext(ctx, query, queryArgs...) + if err != nil { + return nil, err + } + defer rows.Close() + + details := []usageBreakdownDetail{} + for rows.Next() { + var endpoint, model, source, authIndex, apiKeyHash, accountSnapshot, authLabelSnapshot, authFileSnapshot, authProviderSnapshot, authProjectIDSnapshot, resolvedModel string + var failed int + var latestTimestampMS, authSnapshotAtMS int64 + var latencySumMS, latencyCount int64 + detail := usage.Detail{} + if err := rows.Scan( + &endpoint, + &model, + &source, + &authIndex, + &apiKeyHash, + &accountSnapshot, + &authLabelSnapshot, + &authFileSnapshot, + &authProviderSnapshot, + &authProjectIDSnapshot, + &resolvedModel, + &failed, + &latestTimestampMS, + &authSnapshotAtMS, + &detail.RequestCount, + &detail.SuccessCount, + &detail.FailureCount, + &detail.Tokens.InputTokens, + &detail.Tokens.OutputTokens, + &detail.Tokens.ReasoningTokens, + &detail.Tokens.CachedTokens, + &detail.Tokens.CacheTokens, + &detail.Tokens.TotalTokens, + &latencySumMS, + &latencyCount, + ); err != nil { + return nil, err + } + detail.Timestamp = time.UnixMilli(latestTimestampMS).UTC().Format(time.RFC3339Nano) + detail.Source = source + detail.AuthIndex = authIndex + detail.APIKeyHash = apiKeyHash + detail.AccountSnapshot = accountSnapshot + detail.AuthLabelSnapshot = authLabelSnapshot + detail.AuthFileSnapshot = authFileSnapshot + detail.AuthProviderSnapshot = authProviderSnapshot + detail.AuthProjectIDSnapshot = authProjectIDSnapshot + detail.AuthSnapshotAtMS = authSnapshotAtMS + detail.ResolvedModel = resolvedModel + detail.Failed = failed != 0 + detail.LatencySumMS = latencySumMS + detail.LatencyCount = latencyCount + if latencyCount > 0 { + averageLatency := latencySumMS / latencyCount + detail.LatencyMS = &averageLatency } - priceIndex = buildUsageModelPriceIndex(prices) + details = append(details, usageBreakdownDetail{ + Endpoint: endpoint, + Model: model, + TimestampMS: latestTimestampMS, + Detail: detail, + }) } + if err := rows.Err(); err != nil { + return nil, err + } + return details, nil +} + +func usageBreakdownGroupKeySQL(kind UsageBreakdownKind, prefix string) string { switch kind { case UsageBreakdownAccounts: - return buildGroupedUsagePage(kind, details, page, pageSize, pageFilter, accountBreakdownKey, prices, priceIndex), nil + return `coalesce(nullif(` + prefix + `account_snapshot, ''), nullif(` + prefix + `auth_label_snapshot, ''), nullif(` + prefix + `source, ''), nullif(` + prefix + `auth_index, ''), '-')` case UsageBreakdownAPIKeys: - return buildGroupedUsagePage(kind, details, page, pageSize, pageFilter, apiKeyBreakdownKey, prices, priceIndex), nil + return `case + when trim(coalesce(` + prefix + `api_key_hash, '')) != '' then lower(` + prefix + `api_key_hash) + else 'unknown:' || coalesce(` + prefix + `source, '') || ':' || coalesce(` + prefix + `auth_index, '') || ':' || coalesce(` + prefix + `auth_provider_snapshot, '') + end` default: - return UsagePage{}, fmt.Errorf("unknown usage breakdown kind %q", kind) + return `'-'` + } +} + +func usageBreakdownOrderBySQL(filter UsagePageFilter) string { + sortKey := normalizeUsageSortKey(filter.SortKey) + direction := "desc" + if strings.ToLower(strings.TrimSpace(filter.SortDirection)) == "asc" { + direction = "asc" + } + sortExpr := "latest_timestamp" + switch sortKey { + case "totalCalls": + sortExpr = "total_calls" + case "successCalls": + sortExpr = "success_calls" + case "failureCalls": + sortExpr = "failure_calls" + case "successRate": + sortExpr = "case when total_calls <= 0 then 1.0 else cast(success_calls as real) / cast(total_calls as real) end" + case "totalTokens": + sortExpr = "total_tokens" + case "totalCost": + sortExpr = "total_cost" + case "inputTokens": + sortExpr = "input_tokens" + case "outputTokens": + sortExpr = "output_tokens" + case "cachedTokens": + sortExpr = "cached_tokens" + case "lastSeenAt": + sortExpr = "latest_timestamp" } + return sortExpr + " " + direction + ", latest_timestamp " + direction + ", lower(group_key) " + direction +} + +func usageBreakdownCostCTEs() string { + return `, event_price_candidates as ( + select + event_rows.id as event_id, + p.model as price_model, + p.prompt_per_1m, + p.completion_per_1m, + p.cache_per_1m, + ` + usageBreakdownPricePrioritySQL("event_rows.", "p.") + ` as match_priority + from filtered as event_rows + join model_prices p on ` + usageBreakdownPriceMatchSQL("event_rows.", "p.") + ` + ), ranked_event_prices as ( + select + event_id, + price_model, + prompt_per_1m, + completion_per_1m, + cache_per_1m, + row_number() over ( + partition by event_id + order by match_priority, length(price_model), lower(price_model) + ) as price_rank + from event_price_candidates + ), best_event_prices as ( + select event_id, prompt_per_1m, completion_per_1m, cache_per_1m + from ranked_event_prices + where price_rank = 1 + ), priced as ( + select + event_rows.*, + ` + usageBreakdownJoinedCostSQL("event_rows.", "best_event_prices.") + ` as event_cost + from filtered as event_rows + left join best_event_prices on best_event_prices.event_id = event_rows.id + )` +} + +func usageBreakdownPriceMatchSQL(eventPrefix string, pricePrefix string) string { + resolvedLower := `lower(trim(coalesce(` + eventPrefix + `resolved_model, '')))` + modelLower := `lower(trim(coalesce(` + eventPrefix + `model, '')))` + resolvedBase := `usage_last_path_segment(` + resolvedLower + `)` + modelBase := `usage_last_path_segment(` + modelLower + `)` + resolvedStripped := `usage_strip_model_date_suffix(` + resolvedBase + `)` + modelStripped := `usage_strip_model_date_suffix(` + modelBase + `)` + priceLower := `lower(trim(` + pricePrefix + `model))` + priceBase := `usage_last_path_segment(` + priceLower + `)` + priceStripped := `usage_strip_model_date_suffix(` + priceBase + `)` + matchSQL := func(candidateLower, candidateBase, candidateStripped string) string { + return `(` + candidateLower + ` != '' and ( + ` + priceLower + ` = ` + candidateLower + ` or + ` + priceBase + ` = ` + candidateBase + ` or + ` + priceBase + ` = ` + candidateStripped + ` or + ` + priceStripped + ` = ` + candidateStripped + ` or + ` + priceStripped + ` = ` + candidateBase + ` + ))` + } + return matchSQL(resolvedLower, resolvedBase, resolvedStripped) + ` or ` + matchSQL(modelLower, modelBase, modelStripped) +} + +func usageBreakdownPricePrioritySQL(eventPrefix string, pricePrefix string) string { + resolvedLower := `lower(trim(coalesce(` + eventPrefix + `resolved_model, '')))` + modelLower := `lower(trim(coalesce(` + eventPrefix + `model, '')))` + resolvedBase := `usage_last_path_segment(` + resolvedLower + `)` + modelBase := `usage_last_path_segment(` + modelLower + `)` + resolvedStripped := `usage_strip_model_date_suffix(` + resolvedBase + `)` + modelStripped := `usage_strip_model_date_suffix(` + modelBase + `)` + priceLower := `lower(trim(` + pricePrefix + `model))` + priceBase := `usage_last_path_segment(` + priceLower + `)` + priceStripped := `usage_strip_model_date_suffix(` + priceBase + `)` + prioritySQL := func(candidateLower, candidateBase, candidateStripped string, offset int) string { + return `when ` + candidateLower + ` != '' and ` + priceLower + ` = ` + candidateLower + ` then ` + fmt.Sprint(offset) + ` + when ` + candidateLower + ` != '' and ` + priceBase + ` = ` + candidateBase + ` then ` + fmt.Sprint(offset+1) + ` + when ` + candidateLower + ` != '' and ` + priceBase + ` = ` + candidateStripped + ` then ` + fmt.Sprint(offset+2) + ` + when ` + candidateLower + ` != '' and ` + priceStripped + ` = ` + candidateStripped + ` then ` + fmt.Sprint(offset+3) + ` + when ` + candidateLower + ` != '' and ` + priceStripped + ` = ` + candidateBase + ` then ` + fmt.Sprint(offset+4) + } + return `case + ` + prioritySQL(resolvedLower, resolvedBase, resolvedStripped, 0) + ` + ` + prioritySQL(modelLower, modelBase, modelStripped, 10) + ` + else 100 + end` +} + +func usageBreakdownJoinedCostSQL(eventPrefix string, pricePrefix string) string { + cachedTokens := `max(` + eventPrefix + `cached_tokens, ` + eventPrefix + `cache_tokens)` + promptTokens := `max(` + eventPrefix + `input_tokens - ` + cachedTokens + `, 0)` + return `coalesce( + (cast(` + promptTokens + ` as real) / ` + fmt.Sprint(usageTokensPerPriceUnit) + `.0 * ` + pricePrefix + `prompt_per_1m) + + (cast(max(` + cachedTokens + `, 0) as real) / ` + fmt.Sprint(usageTokensPerPriceUnit) + `.0 * ` + pricePrefix + `cache_per_1m) + + (cast(max(` + eventPrefix + `output_tokens, 0) as real) / ` + fmt.Sprint(usageTokensPerPriceUnit) + `.0 * ` + pricePrefix + `completion_per_1m), + 0 + )` } func (s *Store) usageModelPage(ctx context.Context, filter UsageSummaryFilter, page int, pageSize int) (UsagePage, error) { @@ -1661,6 +2010,42 @@ func buildGroupedUsagePage( } } +func buildBreakdownGroupsFromDetails(kind UsageBreakdownKind, groupKeys []string, details []usageBreakdownDetail) []*usageBreakdownGroup { + groups := make([]*usageBreakdownGroup, 0, len(groupKeys)) + groupMap := make(map[string]*usageBreakdownGroup, len(groupKeys)) + keyFunc := accountBreakdownKey + if kind == UsageBreakdownAPIKeys { + keyFunc = apiKeyBreakdownKey + } + for _, key := range groupKeys { + group := &usageBreakdownGroup{Key: key} + groups = append(groups, group) + groupMap[key] = group + } + for _, item := range details { + key := keyFunc(item.Detail) + group := groupMap[key] + if group == nil { + continue + } + group.Details = append(group.Details, item) + group.TotalCalls += item.Detail.RequestCount + group.SuccessCalls += item.Detail.SuccessCount + group.FailureCalls += item.Detail.FailureCount + group.InputTokens += item.Detail.Tokens.InputTokens + group.OutputTokens += item.Detail.Tokens.OutputTokens + group.CachedTokens += maxInt64(item.Detail.Tokens.CachedTokens, item.Detail.Tokens.CacheTokens) + group.TotalTokens += item.Detail.Tokens.TotalTokens + if item.TimestampMS > group.LatestTimestamp { + group.LatestTimestamp = item.TimestampMS + } + } + for _, group := range groups { + sortBreakdownDetails(group.Details) + } + return groups +} + func buildBreakdownPageItems(kind UsageBreakdownKind, groups []*usageBreakdownGroup) []UsageBreakdownPageItem { items := make([]UsageBreakdownPageItem, 0, len(groups)) for _, group := range groups {