Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions internal-packages/llm-model-catalog/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
"@trigger.dev/core": "workspace:*",
"@trigger.dev/database": "workspace:*"
},
"devDependencies": {
"@internal/testcontainers": "workspace:*",
"vitest": "3.1.4"
},
"scripts": {
"test": "vitest --sequence.concurrent=false --no-file-parallelism",
"typecheck": "tsc --noEmit",
"generate": "node scripts/generate.mjs",
"sync-prices": "bash scripts/sync-model-prices.sh && node scripts/generate.mjs",
Expand Down
203 changes: 203 additions & 0 deletions internal-packages/llm-model-catalog/src/sync.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import type { PrismaClient } from "@trigger.dev/database";
import { postgresTest } from "@internal/testcontainers";
import { generateFriendlyId } from "@trigger.dev/core/v3/isomorphic";
import { describe, expect } from "vitest";
import { defaultModelPrices } from "./defaultPrices.js";
import { modelCatalog } from "./modelCatalog.js";
import { syncLlmCatalog } from "./sync.js";

function getGpt4oDefinition() {
const def = defaultModelPrices.find((m) => m.modelName === "gpt-4o");
if (def === undefined) {
throw new Error("expected gpt-4o in defaultModelPrices");
}
return def;
}

const gpt4oDef = getGpt4oDefinition();

function getGeminiProDefinition() {
const def = defaultModelPrices.find((m) => m.modelName === "gemini-pro");
if (def === undefined) {
throw new Error("expected gemini-pro in defaultModelPrices");
}
return def;
}

const geminiProDef = getGeminiProDefinition();

/** If sync used `catalog?.baseModelName ?? existing.baseModelName`, sync would keep this string instead of clearing to null. */
const STALE_BASE_MODEL_NAME = "wrong-base-model-sentinel";

const STALE_INPUT_PRICE = 0.099;
const STALE_OUTPUT_PRICE = 0.088;

async function createGpt4oWithStalePricing(
prisma: PrismaClient,
source: "default" | "admin"
) {
const model = await prisma.llmModel.create({
data: {
friendlyId: generateFriendlyId("llm_model"),
projectId: null,
modelName: gpt4oDef.modelName,
matchPattern: "^stale-pattern$",
startDate: gpt4oDef.startDate ? new Date(gpt4oDef.startDate) : null,
source,
provider: "stale-provider",
description: "stale description",
contextWindow: 111,
maxOutputTokens: 222,
capabilities: ["stale-cap"],
isHidden: true,
baseModelName: "stale-base",
},
});

await prisma.llmPricingTier.create({
data: {
modelId: model.id,
name: "Standard",
isDefault: true,
priority: 0,
conditions: [],
prices: {
create: [
{ modelId: model.id, usageType: "input", price: STALE_INPUT_PRICE },
{ modelId: model.id, usageType: "output", price: STALE_OUTPUT_PRICE },
],
},
},
});

return model;
}

async function createGeminiProWithStaleBaseModelName(prisma: PrismaClient) {
const catalogEntry = modelCatalog[geminiProDef.modelName];
expect(catalogEntry).toBeDefined();
expect(catalogEntry.baseModelName).toBeNull();

const model = await prisma.llmModel.create({
data: {
friendlyId: generateFriendlyId("llm_model"),
projectId: null,
modelName: geminiProDef.modelName,
matchPattern: "^stale-gemini-pattern$",
startDate: geminiProDef.startDate ? new Date(geminiProDef.startDate) : null,
source: "default",
provider: "stale-provider",
description: "stale description",
contextWindow: 111,
maxOutputTokens: 222,
capabilities: ["stale-cap"],
isHidden: true,
baseModelName: STALE_BASE_MODEL_NAME,
},
});

const tier = geminiProDef.pricingTiers[0];
await prisma.llmPricingTier.create({
data: {
modelId: model.id,
name: tier.name,
isDefault: tier.isDefault,
priority: tier.priority,
conditions: tier.conditions,
prices: {
create: Object.entries(tier.prices).map(([usageType, price]) => ({
modelId: model.id,
usageType,
price,
})),
},
},
});

return model;
}

async function loadGpt4oWithTiers(prisma: PrismaClient) {
return prisma.llmModel.findFirst({
where: { projectId: null, modelName: gpt4oDef.modelName },
include: {
pricingTiers: {
include: { prices: true },
orderBy: { priority: "asc" },
},
},
});
}

function expectBundledGpt4oPricing(model: NonNullable<Awaited<ReturnType<typeof loadGpt4oWithTiers>>>) {
expect(model.matchPattern).toBe(gpt4oDef.matchPattern);
expect(model.pricingTiers).toHaveLength(gpt4oDef.pricingTiers.length);

const dbTier = model.pricingTiers[0];
const defTier = gpt4oDef.pricingTiers[0];
expect(dbTier.name).toBe(defTier.name);
expect(dbTier.isDefault).toBe(defTier.isDefault);
expect(dbTier.priority).toBe(defTier.priority);

const priceByType = new Map(dbTier.prices.map((p) => [p.usageType, Number(p.price)]));
for (const [usageType, expected] of Object.entries(defTier.prices)) {
expect(priceByType.get(usageType)).toBeCloseTo(expected, 12);
}
expect(priceByType.size).toBe(Object.keys(defTier.prices).length);
}

describe("syncLlmCatalog", () => {
postgresTest(
"rebuilds gpt-4o pricing tiers from bundled defaults when source is default",
async ({ prisma }) => {
await createGpt4oWithStalePricing(prisma, "default");

const result = await syncLlmCatalog(prisma);

expect(result.modelsUpdated).toBe(1);
expect(result.modelsSkipped).toBe(defaultModelPrices.length - 1);

const after = await loadGpt4oWithTiers(prisma);
expect(after).not.toBeNull();
expectBundledGpt4oPricing(after!);
}
);

postgresTest(
"does not replace pricing tiers when model source is not default",
async ({ prisma }) => {
await createGpt4oWithStalePricing(prisma, "admin");

const result = await syncLlmCatalog(prisma);

expect(result.modelsUpdated).toBe(0);
expect(result.modelsSkipped).toBeGreaterThanOrEqual(1);

const after = await loadGpt4oWithTiers(prisma);
expect(after).not.toBeNull();
expect(after!.matchPattern).toBe("^stale-pattern$");
expect(after!.pricingTiers).toHaveLength(1);
const prices = after!.pricingTiers[0].prices;
const input = prices.find((p) => p.usageType === "input");
const output = prices.find((p) => p.usageType === "output");
expect(Number(input?.price)).toBeCloseTo(STALE_INPUT_PRICE, 12);
expect(Number(output?.price)).toBeCloseTo(STALE_OUTPUT_PRICE, 12);
expect(prices).toHaveLength(2);
}
);

postgresTest(
"clears baseModelName when bundled catalog has null (regression for nullish-coalescing merge)",
async ({ prisma }) => {
await createGeminiProWithStaleBaseModelName(prisma);

await syncLlmCatalog(prisma);

const after = await prisma.llmModel.findFirst({
where: { projectId: null, modelName: geminiProDef.modelName },
});
expect(after).not.toBeNull();
expect(after!.baseModelName).toBeNull();
}
);
});
80 changes: 63 additions & 17 deletions internal-packages/llm-model-catalog/src/sync.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
import type { PrismaClient } from "@trigger.dev/database";
import type { Prisma, PrismaClient } from "@trigger.dev/database";
import { defaultModelPrices } from "./defaultPrices.js";
import { modelCatalog } from "./modelCatalog.js";
import type { DefaultModelDefinition } from "./types.js";

function pricingTierCreateData(
modelId: string,
tier: DefaultModelDefinition["pricingTiers"][number]
): Prisma.LlmPricingTierUncheckedCreateInput {
return {
modelId,
name: tier.name,
isDefault: tier.isDefault,
priority: tier.priority,
conditions: tier.conditions,
prices: {
create: Object.entries(tier.prices).map(([usageType, price]) => ({
modelId,
usageType,
price,
})),
},
};
}

export async function syncLlmCatalog(prisma: PrismaClient): Promise<{
modelsUpdated: number;
Expand Down Expand Up @@ -31,24 +52,49 @@ export async function syncLlmCatalog(prisma: PrismaClient): Promise<{

const catalog = modelCatalog[modelDef.modelName];

await prisma.llmModel.update({
where: { id: existing.id },
data: {
// Update match pattern and start date from Langfuse (may have changed)
matchPattern: modelDef.matchPattern,
startDate: modelDef.startDate ? new Date(modelDef.startDate) : null,
// Update catalog metadata
provider: catalog?.provider ?? existing.provider,
description: catalog?.description ?? existing.description,
contextWindow: catalog?.contextWindow ?? existing.contextWindow,
maxOutputTokens: catalog?.maxOutputTokens ?? existing.maxOutputTokens,
capabilities: catalog?.capabilities ?? existing.capabilities,
isHidden: catalog?.isHidden ?? existing.isHidden,
baseModelName: catalog?.baseModelName ?? existing.baseModelName,
},
const applied = await prisma.$transaction(async (tx) => {
const updateResult = await tx.llmModel.updateMany({
where: { id: existing.id, source: "default" },
data: {
// Update match pattern and start date from Langfuse (may have changed)
matchPattern: modelDef.matchPattern,
startDate: modelDef.startDate ? new Date(modelDef.startDate) : null,
// Update catalog metadata
provider: catalog?.provider ?? existing.provider,
description: catalog?.description ?? existing.description,
contextWindow:
catalog?.contextWindow === undefined ? existing.contextWindow : catalog.contextWindow,
maxOutputTokens:
catalog?.maxOutputTokens === undefined
? existing.maxOutputTokens
: catalog.maxOutputTokens,
capabilities: catalog?.capabilities ?? existing.capabilities,
isHidden: catalog?.isHidden ?? existing.isHidden,
baseModelName:
catalog?.baseModelName === undefined
? existing.baseModelName
: catalog.baseModelName,
},
});

if (updateResult.count !== 1) {
return false;
}

await tx.llmPricingTier.deleteMany({ where: { modelId: existing.id } });

for (const tier of modelDef.pricingTiers) {
await tx.llmPricingTier.create({
data: pricingTierCreateData(existing.id, tier),
});
}

return true;
});

modelsUpdated++;
if (applied) {
modelsUpdated++;
}
}

return { modelsUpdated, modelsSkipped };
Expand Down
2 changes: 1 addition & 1 deletion internal-packages/llm-model-catalog/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
"strict": true,
"resolveJsonModule": true
},
"exclude": ["node_modules"]
"exclude": ["node_modules", "**/*.test.ts"]
}
16 changes: 16 additions & 0 deletions internal-packages/llm-model-catalog/vitest.config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import { defineConfig } from "vitest/config";

export default defineConfig({
test: {
include: ["**/*.test.ts"],
globals: true,
isolate: true,
fileParallelism: false,
poolOptions: {
threads: {
singleThread: true,
},
},
testTimeout: 120_000,
},
});
9 changes: 8 additions & 1 deletion pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading