Skip to content

Commit 9e3b117

Browse files
fixed tool calling in playground and added playwright test (tensorzero#3032)
* fixed tool calling in playground and added playwright test * Regenerate ModelInferenceCache fixtures * added extra inference options code * removed console log * Regenerate ModelInferenceCache fixtures --------- Co-authored-by: TensorZero Bot <github-actions[bot]@users.noreply.github.com>
1 parent 1b3cbad commit 9e3b117

9 files changed

Lines changed: 268 additions & 137 deletions

File tree

ui/app/routes/api/tensorzero/inference.ts

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ export async function action({ request }: Route.ActionArgs): Promise<Response> {
2525
{ status: 400 },
2626
);
2727
}
28-
2928
if (isTensorZeroServerError(error)) {
3029
return Response.json({ error: error.message }, { status: error.status });
3130
}

ui/app/routes/api/tensorzero/inference.utils.tsx

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ import type {
77
ClientInput,
88
ClientInputMessage,
99
ClientInputMessageContent,
10+
FunctionConfig,
1011
JsonValue,
12+
Tool,
1113
} from "tensorzero-node";
1214
import type {
1315
InputMessageContent as TensorZeroContent,
@@ -292,6 +294,7 @@ interface ClickHouseDatapointActionArgs {
292294
variant: string;
293295
cache_options: CacheParamsOptions;
294296
dryrun: boolean;
297+
functionConfig: FunctionConfig;
295298
}
296299

297300
export function prepareInferenceActionRequest(
@@ -351,7 +354,12 @@ export function prepareInferenceActionRequest(
351354
// Extract tool parameters from the ClickHouse datapoint args
352355
const tool_choice = args.tool_params?.tool_choice;
353356
const parallel_tool_calls = args.tool_params?.parallel_tool_calls;
354-
const tools_available = args.tool_params?.tools_available;
357+
const additional_tools = args.tool_params?.tools_available
358+
? subtractStaticToolsFromInferenceInput(
359+
args.tool_params?.tools_available,
360+
args.functionConfig,
361+
)
362+
: null;
355363

356364
return {
357365
...baseParams,
@@ -362,7 +370,7 @@ export function prepareInferenceActionRequest(
362370
tool_choice: tool_choice || null,
363371
dryrun: true,
364372
parallel_tool_calls: parallel_tool_calls || null,
365-
additional_tools: tools_available || null,
373+
additional_tools,
366374
cache_options: args.cache_options,
367375
};
368376
} else {
@@ -589,3 +597,26 @@ function resolvedFileContentToClientFile(
589597
data: data,
590598
};
591599
}
600+
601+
/*
602+
* For both inferences and datapoints, we store a full tool config that
603+
* specifies what the model saw or could have seen at inference time for a particular example.
604+
* However, TensorZero will automatically use the tools that are currently configured for inferences.
605+
* It will also error if there are tools with duplicated names. In order to avoid this, we "subtract"
606+
* out all currently configured tools from the tools that we pass in dynamically.
607+
*/
608+
function subtractStaticToolsFromInferenceInput(
609+
datapointTools: Tool[],
610+
functionConfig: FunctionConfig,
611+
): Tool[] {
612+
if (functionConfig.type === "json") {
613+
return datapointTools;
614+
}
615+
const resultTools = [];
616+
for (const tool of datapointTools) {
617+
if (!functionConfig.tools.some((t) => t === tool.name)) {
618+
resultTools.push(tool);
619+
}
620+
}
621+
return resultTools;
622+
}

ui/app/routes/playground/DatapointPlaygroundOutput.tsx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ import { Button } from "~/components/ui/button";
77
import { CodeEditor } from "~/components/ui/code-editor";
88
import { refreshClientInference } from "./utils";
99
import type { DisplayInput } from "~/utils/clickhouse/common";
10-
import type { Datapoint, InferenceResponse } from "tensorzero-node";
10+
import type {
11+
Datapoint,
12+
FunctionConfig,
13+
InferenceResponse,
14+
} from "tensorzero-node";
1115

1216
interface DatapointPlaygroundOutputProps {
1317
datapoint: Datapoint;
@@ -21,6 +25,7 @@ interface DatapointPlaygroundOutputProps {
2125
) => void;
2226
input: DisplayInput;
2327
functionName: string;
28+
functionConfig: FunctionConfig;
2429
}
2530
const DatapointPlaygroundOutput = memo(
2631
function DatapointPlaygroundOutput({
@@ -31,6 +36,7 @@ const DatapointPlaygroundOutput = memo(
3136
input,
3237
functionName,
3338
isLoading,
39+
functionConfig,
3440
}: DatapointPlaygroundOutputProps) {
3541
const loadingIndicator = (
3642
<div className="flex min-h-[8rem] items-center justify-center">
@@ -49,6 +55,7 @@ const DatapointPlaygroundOutput = memo(
4955
datapoint,
5056
variantName,
5157
functionName,
58+
functionConfig,
5259
);
5360
}}
5461
>

ui/app/routes/playground/route.tsx

Lines changed: 76 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ import {
2121
} from "~/routes/api/tensorzero/inference.utils";
2222
import { resolveInput } from "~/utils/resolve.server";
2323
import { X } from "lucide-react";
24-
import type { Datapoint as TensorZeroDatapoint } from "tensorzero-node";
24+
import type {
25+
FunctionConfig,
26+
Datapoint as TensorZeroDatapoint,
27+
} from "tensorzero-node";
2528
import type { DisplayInput } from "~/utils/clickhouse/common";
2629
import { useCallback, useEffect, useMemo, useState } from "react";
2730
import { Button } from "~/components/ui/button";
@@ -148,26 +151,30 @@ export async function loader({ request }: Route.LoaderArgs) {
148151
datapoint: TensorZeroDatapoint,
149152
functionName: string,
150153
variantName: string,
154+
functionConfig: FunctionConfig,
151155
) => {
152-
const request = prepareInferenceActionRequest({
153-
source: "clickhouse_datapoint",
154-
input,
155-
functionName,
156-
variant: variantName,
157-
tool_params:
158-
datapoint?.type === "chat"
159-
? (datapoint.tool_params ?? undefined)
160-
: undefined,
161-
output_schema:
162-
datapoint?.type === "json" ? datapoint.output_schema : null,
163-
// The default is write_only but we do off in the playground
164-
cache_options: {
165-
max_age_s: null,
166-
enabled: "off",
167-
},
168-
dryrun: true,
156+
const request = {
157+
...prepareInferenceActionRequest({
158+
source: "clickhouse_datapoint",
159+
input,
160+
functionName,
161+
variant: variantName,
162+
tool_params:
163+
datapoint?.type === "chat"
164+
? (datapoint.tool_params ?? undefined)
165+
: undefined,
166+
output_schema:
167+
datapoint?.type === "json" ? datapoint.output_schema : null,
168+
// The default is write_only but we do off in the playground
169+
cache_options: {
170+
max_age_s: null,
171+
enabled: "off",
172+
},
173+
dryrun: true,
174+
functionConfig,
175+
}),
169176
...getExtraInferenceOptions(),
170-
});
177+
};
171178
const nativeClient = await getNativeTensorZeroClient();
172179
const inferenceResponse = await nativeClient.inference(request);
173180
return inferenceResponse;
@@ -183,7 +190,7 @@ export async function loader({ request }: Route.LoaderArgs) {
183190
for (const variant of selectedVariants) {
184191
serverInferences.set(variant, new Map());
185192
}
186-
if (datapoints && inputs && functionName) {
193+
if (datapoints && inputs && functionName && functionConfig) {
187194
for (let index = 0; index < datapoints.length; index++) {
188195
const datapoint = datapoints[index];
189196
const input = inputs[index];
@@ -192,7 +199,13 @@ export async function loader({ request }: Route.LoaderArgs) {
192199
.get(variant)
193200
?.set(
194201
datapoint.id,
195-
serverInference(input, datapoint, functionName, variant),
202+
serverInference(
203+
input,
204+
datapoint,
205+
functionName,
206+
variant,
207+
functionConfig,
208+
),
196209
);
197210
}
198211
}
@@ -248,20 +261,21 @@ export default function PlaygroundPage({ loaderData }: Route.ComponentProps) {
248261
offset,
249262
limit,
250263
} = loaderData;
264+
const functionConfig = useFunctionConfig(functionName);
265+
if (functionName && !functionConfig) {
266+
throw data(`Function config not found for function ${functionName}`, {
267+
status: 404,
268+
});
269+
}
251270
const { map, setPromise } = useClientInferences(
252271
functionName,
253272
datapoints,
254273
inputs,
255274
selectedVariants,
256275
serverInferences,
276+
functionConfig,
257277
);
258278

259-
const functionConfig = useFunctionConfig(functionName);
260-
if (functionName && !functionConfig) {
261-
throw data(`Function config not found for function ${functionName}`, {
262-
status: 404,
263-
});
264-
}
265279
const variants = functionConfig?.variants ?? undefined;
266280
const variantData = variants
267281
? Object.entries(variants).map(([variantName]) => ({
@@ -334,7 +348,8 @@ export default function PlaygroundPage({ loaderData }: Route.ComponentProps) {
334348
datapoints.length > 0 &&
335349
datasetName &&
336350
inputs &&
337-
functionName && (
351+
functionName &&
352+
functionConfig && (
338353
<>
339354
<div className="overflow-x-auto rounded border">
340355
<div className="min-w-fit">
@@ -427,6 +442,7 @@ export default function PlaygroundPage({ loaderData }: Route.ComponentProps) {
427442
setPromise={setPromise}
428443
input={inputs[index]}
429444
functionName={functionName}
445+
functionConfig={functionConfig}
430446
/>
431447
</div>
432448
);
@@ -483,13 +499,14 @@ function useClientInferences(
483499
inputs: DisplayInput[] | undefined,
484500
selectedVariants: string[],
485501
serverInferences: NestedPromiseMap<InferenceResponse>,
502+
functionConfig: FunctionConfig | null,
486503
) {
487504
const { map, setPromise, setMap } =
488505
useNestedPromiseMap<InferenceResponse>(serverInferences);
489506

490507
// Single combined effect to handle both server inferences and client inferences
491508
useEffect(() => {
492-
if (!functionName || !datapoints || !inputs) return;
509+
if (!functionName || !datapoints || !inputs || !functionConfig) return;
493510

494511
// First check if we need any updates
495512
let needsUpdate = false;
@@ -529,24 +546,28 @@ function useClientInferences(
529546
newMap.set(variant, variantMap);
530547
}
531548

532-
const request = prepareInferenceActionRequest({
533-
source: "clickhouse_datapoint",
534-
input,
535-
functionName,
536-
variant: variant,
537-
tool_params:
538-
datapoint?.type === "chat"
539-
? (datapoint.tool_params ?? undefined)
540-
: undefined,
541-
output_schema:
542-
datapoint?.type === "json" ? datapoint.output_schema : null,
543-
// The default is write_only but we do off in the playground
544-
cache_options: {
545-
max_age_s: null,
546-
enabled: "off",
547-
},
548-
dryrun: true,
549-
});
549+
const request = {
550+
...prepareInferenceActionRequest({
551+
source: "clickhouse_datapoint",
552+
input,
553+
functionName,
554+
variant: variant,
555+
tool_params:
556+
datapoint?.type === "chat"
557+
? (datapoint.tool_params ?? undefined)
558+
: undefined,
559+
output_schema:
560+
datapoint?.type === "json" ? datapoint.output_schema : null,
561+
// The default is write_only but we do off in the playground
562+
cache_options: {
563+
max_age_s: null,
564+
enabled: "off",
565+
},
566+
dryrun: true,
567+
functionConfig,
568+
}),
569+
...getExtraInferenceOptions(),
570+
};
550571
const formData = new FormData();
551572
formData.append("data", JSON.stringify(request));
552573
const responsePromise = fetch("/api/tensorzero/inference", {
@@ -564,7 +585,14 @@ function useClientInferences(
564585

565586
return newMap;
566587
});
567-
}, [functionName, datapoints, inputs, selectedVariants, setMap]);
588+
}, [
589+
functionName,
590+
datapoints,
591+
inputs,
592+
selectedVariants,
593+
setMap,
594+
functionConfig,
595+
]);
568596

569597
return { map, setPromise, setMap };
570598
}

ui/app/routes/playground/utils.ts

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ import type { DisplayInput } from "~/utils/clickhouse/common";
22
import type {
33
Datapoint as TensorZeroDatapoint,
44
InferenceResponse,
5+
FunctionConfig,
56
} from "tensorzero-node";
67
import { prepareInferenceActionRequest } from "../api/tensorzero/inference.utils";
8+
import { getExtraInferenceOptions } from "~/utils/feature_flags";
79

810
export function refreshClientInference(
911
setPromise: (
@@ -15,23 +17,29 @@ export function refreshClientInference(
1517
datapoint: TensorZeroDatapoint,
1618
variantName: string,
1719
functionName: string,
20+
functionConfig: FunctionConfig,
1821
) {
19-
const request = prepareInferenceActionRequest({
20-
source: "clickhouse_datapoint",
21-
input,
22-
functionName,
23-
variant: variantName,
24-
tool_params:
25-
datapoint?.type === "chat"
26-
? (datapoint.tool_params ?? undefined)
27-
: undefined,
28-
output_schema: datapoint?.type === "json" ? datapoint.output_schema : null,
29-
cache_options: {
30-
max_age_s: null,
31-
enabled: "off",
32-
},
33-
dryrun: true,
34-
});
22+
const request = {
23+
...prepareInferenceActionRequest({
24+
source: "clickhouse_datapoint",
25+
input,
26+
functionName,
27+
variant: variantName,
28+
tool_params:
29+
datapoint?.type === "chat"
30+
? (datapoint.tool_params ?? undefined)
31+
: undefined,
32+
output_schema:
33+
datapoint?.type === "json" ? datapoint.output_schema : null,
34+
cache_options: {
35+
max_age_s: null,
36+
enabled: "off",
37+
},
38+
dryrun: true,
39+
functionConfig,
40+
}),
41+
...getExtraInferenceOptions(),
42+
};
3543
// The API endpoint takes form data so we need to stringify it and send as data
3644
const formData = new FormData();
3745
formData.append("data", JSON.stringify(request));

ui/app/utils/clickhouse/inference.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
import type {
1212
JsonInferenceOutput,
1313
ContentBlockChatOutput,
14+
Tool,
1415
} from "tensorzero-node";
1516

1617
// Zod schemas for ToolCallConfigDatabaseInsert
@@ -19,7 +20,7 @@ export const toolSchema = z.object({
1920
parameters: JsonValueSchema,
2021
name: z.string(),
2122
strict: z.boolean(),
22-
});
23+
}) satisfies z.ZodType<Tool>;
2324

2425
export const toolChoiceSchema = z.union([
2526
z.literal("none"),

0 commit comments

Comments
 (0)