diff --git a/internal-packages/llm-model-catalog/package.json b/internal-packages/llm-model-catalog/package.json index be27ce3529d..ac5cdafc0a6 100644 --- a/internal-packages/llm-model-catalog/package.json +++ b/internal-packages/llm-model-catalog/package.json @@ -9,7 +9,12 @@ "@trigger.dev/core": "workspace:*", "@trigger.dev/database": "workspace:*" }, + "devDependencies": { + "@internal/testcontainers": "workspace:*", + "vitest": "3.1.4" + }, "scripts": { + "test": "vitest --sequence.concurrent=false --no-file-parallelism", "typecheck": "tsc --noEmit", "generate": "node scripts/generate.mjs", "sync-prices": "bash scripts/sync-model-prices.sh && node scripts/generate.mjs", diff --git a/internal-packages/llm-model-catalog/src/sync.test.ts b/internal-packages/llm-model-catalog/src/sync.test.ts new file mode 100644 index 00000000000..d2138564ef6 --- /dev/null +++ b/internal-packages/llm-model-catalog/src/sync.test.ts @@ -0,0 +1,203 @@ +import type { PrismaClient } from "@trigger.dev/database"; +import { postgresTest } from "@internal/testcontainers"; +import { generateFriendlyId } from "@trigger.dev/core/v3/isomorphic"; +import { describe, expect } from "vitest"; +import { defaultModelPrices } from "./defaultPrices.js"; +import { modelCatalog } from "./modelCatalog.js"; +import { syncLlmCatalog } from "./sync.js"; + +function getGpt4oDefinition() { + const def = defaultModelPrices.find((m) => m.modelName === "gpt-4o"); + if (def === undefined) { + throw new Error("expected gpt-4o in defaultModelPrices"); + } + return def; +} + +const gpt4oDef = getGpt4oDefinition(); + +function getGeminiProDefinition() { + const def = defaultModelPrices.find((m) => m.modelName === "gemini-pro"); + if (def === undefined) { + throw new Error("expected gemini-pro in defaultModelPrices"); + } + return def; +} + +const geminiProDef = getGeminiProDefinition(); + +/** If sync used `catalog?.baseModelName ?? existing.baseModelName`, sync would keep this string instead of clearing to null. */ +const STALE_BASE_MODEL_NAME = "wrong-base-model-sentinel"; + +const STALE_INPUT_PRICE = 0.099; +const STALE_OUTPUT_PRICE = 0.088; + +async function createGpt4oWithStalePricing( + prisma: PrismaClient, + source: "default" | "admin" +) { + const model = await prisma.llmModel.create({ + data: { + friendlyId: generateFriendlyId("llm_model"), + projectId: null, + modelName: gpt4oDef.modelName, + matchPattern: "^stale-pattern$", + startDate: gpt4oDef.startDate ? new Date(gpt4oDef.startDate) : null, + source, + provider: "stale-provider", + description: "stale description", + contextWindow: 111, + maxOutputTokens: 222, + capabilities: ["stale-cap"], + isHidden: true, + baseModelName: "stale-base", + }, + }); + + await prisma.llmPricingTier.create({ + data: { + modelId: model.id, + name: "Standard", + isDefault: true, + priority: 0, + conditions: [], + prices: { + create: [ + { modelId: model.id, usageType: "input", price: STALE_INPUT_PRICE }, + { modelId: model.id, usageType: "output", price: STALE_OUTPUT_PRICE }, + ], + }, + }, + }); + + return model; +} + +async function createGeminiProWithStaleBaseModelName(prisma: PrismaClient) { + const catalogEntry = modelCatalog[geminiProDef.modelName]; + expect(catalogEntry).toBeDefined(); + expect(catalogEntry.baseModelName).toBeNull(); + + const model = await prisma.llmModel.create({ + data: { + friendlyId: generateFriendlyId("llm_model"), + projectId: null, + modelName: geminiProDef.modelName, + matchPattern: "^stale-gemini-pattern$", + startDate: geminiProDef.startDate ? new Date(geminiProDef.startDate) : null, + source: "default", + provider: "stale-provider", + description: "stale description", + contextWindow: 111, + maxOutputTokens: 222, + capabilities: ["stale-cap"], + isHidden: true, + baseModelName: STALE_BASE_MODEL_NAME, + }, + }); + + const tier = geminiProDef.pricingTiers[0]; + await prisma.llmPricingTier.create({ + data: { + modelId: model.id, + name: tier.name, + isDefault: tier.isDefault, + priority: tier.priority, + conditions: tier.conditions, + prices: { + create: Object.entries(tier.prices).map(([usageType, price]) => ({ + modelId: model.id, + usageType, + price, + })), + }, + }, + }); + + return model; +} + +async function loadGpt4oWithTiers(prisma: PrismaClient) { + return prisma.llmModel.findFirst({ + where: { projectId: null, modelName: gpt4oDef.modelName }, + include: { + pricingTiers: { + include: { prices: true }, + orderBy: { priority: "asc" }, + }, + }, + }); +} + +function expectBundledGpt4oPricing(model: NonNullable>>) { + expect(model.matchPattern).toBe(gpt4oDef.matchPattern); + expect(model.pricingTiers).toHaveLength(gpt4oDef.pricingTiers.length); + + const dbTier = model.pricingTiers[0]; + const defTier = gpt4oDef.pricingTiers[0]; + expect(dbTier.name).toBe(defTier.name); + expect(dbTier.isDefault).toBe(defTier.isDefault); + expect(dbTier.priority).toBe(defTier.priority); + + const priceByType = new Map(dbTier.prices.map((p) => [p.usageType, Number(p.price)])); + for (const [usageType, expected] of Object.entries(defTier.prices)) { + expect(priceByType.get(usageType)).toBeCloseTo(expected, 12); + } + expect(priceByType.size).toBe(Object.keys(defTier.prices).length); +} + +describe("syncLlmCatalog", () => { + postgresTest( + "rebuilds gpt-4o pricing tiers from bundled defaults when source is default", + async ({ prisma }) => { + await createGpt4oWithStalePricing(prisma, "default"); + + const result = await syncLlmCatalog(prisma); + + expect(result.modelsUpdated).toBe(1); + expect(result.modelsSkipped).toBe(defaultModelPrices.length - 1); + + const after = await loadGpt4oWithTiers(prisma); + expect(after).not.toBeNull(); + expectBundledGpt4oPricing(after!); + } + ); + + postgresTest( + "does not replace pricing tiers when model source is not default", + async ({ prisma }) => { + await createGpt4oWithStalePricing(prisma, "admin"); + + const result = await syncLlmCatalog(prisma); + + expect(result.modelsUpdated).toBe(0); + expect(result.modelsSkipped).toBeGreaterThanOrEqual(1); + + const after = await loadGpt4oWithTiers(prisma); + expect(after).not.toBeNull(); + expect(after!.matchPattern).toBe("^stale-pattern$"); + expect(after!.pricingTiers).toHaveLength(1); + const prices = after!.pricingTiers[0].prices; + const input = prices.find((p) => p.usageType === "input"); + const output = prices.find((p) => p.usageType === "output"); + expect(Number(input?.price)).toBeCloseTo(STALE_INPUT_PRICE, 12); + expect(Number(output?.price)).toBeCloseTo(STALE_OUTPUT_PRICE, 12); + expect(prices).toHaveLength(2); + } + ); + + postgresTest( + "clears baseModelName when bundled catalog has null (regression for nullish-coalescing merge)", + async ({ prisma }) => { + await createGeminiProWithStaleBaseModelName(prisma); + + await syncLlmCatalog(prisma); + + const after = await prisma.llmModel.findFirst({ + where: { projectId: null, modelName: geminiProDef.modelName }, + }); + expect(after).not.toBeNull(); + expect(after!.baseModelName).toBeNull(); + } + ); +}); diff --git a/internal-packages/llm-model-catalog/src/sync.ts b/internal-packages/llm-model-catalog/src/sync.ts index b600e39a692..aa0611dbda2 100644 --- a/internal-packages/llm-model-catalog/src/sync.ts +++ b/internal-packages/llm-model-catalog/src/sync.ts @@ -1,6 +1,27 @@ -import type { PrismaClient } from "@trigger.dev/database"; +import type { Prisma, PrismaClient } from "@trigger.dev/database"; import { defaultModelPrices } from "./defaultPrices.js"; import { modelCatalog } from "./modelCatalog.js"; +import type { DefaultModelDefinition } from "./types.js"; + +function pricingTierCreateData( + modelId: string, + tier: DefaultModelDefinition["pricingTiers"][number] +): Prisma.LlmPricingTierUncheckedCreateInput { + return { + modelId, + name: tier.name, + isDefault: tier.isDefault, + priority: tier.priority, + conditions: tier.conditions, + prices: { + create: Object.entries(tier.prices).map(([usageType, price]) => ({ + modelId, + usageType, + price, + })), + }, + }; +} export async function syncLlmCatalog(prisma: PrismaClient): Promise<{ modelsUpdated: number; @@ -31,24 +52,49 @@ export async function syncLlmCatalog(prisma: PrismaClient): Promise<{ const catalog = modelCatalog[modelDef.modelName]; - await prisma.llmModel.update({ - where: { id: existing.id }, - data: { - // Update match pattern and start date from Langfuse (may have changed) - matchPattern: modelDef.matchPattern, - startDate: modelDef.startDate ? new Date(modelDef.startDate) : null, - // Update catalog metadata - provider: catalog?.provider ?? existing.provider, - description: catalog?.description ?? existing.description, - contextWindow: catalog?.contextWindow ?? existing.contextWindow, - maxOutputTokens: catalog?.maxOutputTokens ?? existing.maxOutputTokens, - capabilities: catalog?.capabilities ?? existing.capabilities, - isHidden: catalog?.isHidden ?? existing.isHidden, - baseModelName: catalog?.baseModelName ?? existing.baseModelName, - }, + const applied = await prisma.$transaction(async (tx) => { + const updateResult = await tx.llmModel.updateMany({ + where: { id: existing.id, source: "default" }, + data: { + // Update match pattern and start date from Langfuse (may have changed) + matchPattern: modelDef.matchPattern, + startDate: modelDef.startDate ? new Date(modelDef.startDate) : null, + // Update catalog metadata + provider: catalog?.provider ?? existing.provider, + description: catalog?.description ?? existing.description, + contextWindow: + catalog?.contextWindow === undefined ? existing.contextWindow : catalog.contextWindow, + maxOutputTokens: + catalog?.maxOutputTokens === undefined + ? existing.maxOutputTokens + : catalog.maxOutputTokens, + capabilities: catalog?.capabilities ?? existing.capabilities, + isHidden: catalog?.isHidden ?? existing.isHidden, + baseModelName: + catalog?.baseModelName === undefined + ? existing.baseModelName + : catalog.baseModelName, + }, + }); + + if (updateResult.count !== 1) { + return false; + } + + await tx.llmPricingTier.deleteMany({ where: { modelId: existing.id } }); + + for (const tier of modelDef.pricingTiers) { + await tx.llmPricingTier.create({ + data: pricingTierCreateData(existing.id, tier), + }); + } + + return true; }); - modelsUpdated++; + if (applied) { + modelsUpdated++; + } } return { modelsUpdated, modelsSkipped }; diff --git a/internal-packages/llm-model-catalog/tsconfig.json b/internal-packages/llm-model-catalog/tsconfig.json index c64cf33133b..7c980007eef 100644 --- a/internal-packages/llm-model-catalog/tsconfig.json +++ b/internal-packages/llm-model-catalog/tsconfig.json @@ -15,5 +15,5 @@ "strict": true, "resolveJsonModule": true }, - "exclude": ["node_modules"] + "exclude": ["node_modules", "**/*.test.ts"] } diff --git a/internal-packages/llm-model-catalog/vitest.config.ts b/internal-packages/llm-model-catalog/vitest.config.ts new file mode 100644 index 00000000000..474961216bf --- /dev/null +++ b/internal-packages/llm-model-catalog/vitest.config.ts @@ -0,0 +1,16 @@ +import { defineConfig } from "vitest/config"; + +export default defineConfig({ + test: { + include: ["**/*.test.ts"], + globals: true, + isolate: true, + fileParallelism: false, + poolOptions: { + threads: { + singleThread: true, + }, + }, + testTimeout: 120_000, + }, +}); diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 192a5747f2a..03c1bf16786 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1142,6 +1142,13 @@ importers: '@trigger.dev/database': specifier: workspace:* version: link:../database + devDependencies: + '@internal/testcontainers': + specifier: workspace:* + version: link:../testcontainers + vitest: + specifier: 3.1.4 + version: 3.1.4(@types/debug@4.1.12)(@types/node@20.14.14)(lightningcss@1.29.2)(terser@5.44.1) internal-packages/otlp-importer: dependencies: @@ -42268,7 +42275,7 @@ snapshots: vite-node@3.1.4(@types/node@20.14.14)(lightningcss@1.29.2)(terser@5.44.1): dependencies: cac: 6.7.14 - debug: 4.4.1 + debug: 4.4.3(supports-color@10.0.0) es-module-lexer: 1.7.0 pathe: 2.0.3 vite: 5.4.21(@types/node@20.14.14)(lightningcss@1.29.2)(terser@5.44.1)