@@ -21,7 +21,10 @@ import {
2121} from "~/routes/api/tensorzero/inference.utils" ;
2222import { resolveInput } from "~/utils/resolve.server" ;
2323import { X } from "lucide-react" ;
24- import type { Datapoint as TensorZeroDatapoint } from "tensorzero-node" ;
24+ import type {
25+ FunctionConfig ,
26+ Datapoint as TensorZeroDatapoint ,
27+ } from "tensorzero-node" ;
2528import type { DisplayInput } from "~/utils/clickhouse/common" ;
2629import { useCallback , useEffect , useMemo , useState } from "react" ;
2730import { Button } from "~/components/ui/button" ;
@@ -148,26 +151,30 @@ export async function loader({ request }: Route.LoaderArgs) {
148151 datapoint : TensorZeroDatapoint ,
149152 functionName : string ,
150153 variantName : string ,
154+ functionConfig : FunctionConfig ,
151155 ) => {
152- const request = prepareInferenceActionRequest ( {
153- source : "clickhouse_datapoint" ,
154- input,
155- functionName,
156- variant : variantName ,
157- tool_params :
158- datapoint ?. type === "chat"
159- ? ( datapoint . tool_params ?? undefined )
160- : undefined ,
161- output_schema :
162- datapoint ?. type === "json" ? datapoint . output_schema : null ,
163- // The default is write_only but we do off in the playground
164- cache_options : {
165- max_age_s : null ,
166- enabled : "off" ,
167- } ,
168- dryrun : true ,
156+ const request = {
157+ ...prepareInferenceActionRequest ( {
158+ source : "clickhouse_datapoint" ,
159+ input,
160+ functionName,
161+ variant : variantName ,
162+ tool_params :
163+ datapoint ?. type === "chat"
164+ ? ( datapoint . tool_params ?? undefined )
165+ : undefined ,
166+ output_schema :
167+ datapoint ?. type === "json" ? datapoint . output_schema : null ,
168+ // The default is write_only, but we turn caching off in the playground
169+ cache_options : {
170+ max_age_s : null ,
171+ enabled : "off" ,
172+ } ,
173+ dryrun : true ,
174+ functionConfig,
175+ } ) ,
169176 ...getExtraInferenceOptions ( ) ,
170- } ) ;
177+ } ;
171178 const nativeClient = await getNativeTensorZeroClient ( ) ;
172179 const inferenceResponse = await nativeClient . inference ( request ) ;
173180 return inferenceResponse ;
@@ -183,7 +190,7 @@ export async function loader({ request }: Route.LoaderArgs) {
183190 for ( const variant of selectedVariants ) {
184191 serverInferences . set ( variant , new Map ( ) ) ;
185192 }
186- if ( datapoints && inputs && functionName ) {
193+ if ( datapoints && inputs && functionName && functionConfig ) {
187194 for ( let index = 0 ; index < datapoints . length ; index ++ ) {
188195 const datapoint = datapoints [ index ] ;
189196 const input = inputs [ index ] ;
@@ -192,7 +199,13 @@ export async function loader({ request }: Route.LoaderArgs) {
192199 . get ( variant )
193200 ?. set (
194201 datapoint . id ,
195- serverInference ( input , datapoint , functionName , variant ) ,
202+ serverInference (
203+ input ,
204+ datapoint ,
205+ functionName ,
206+ variant ,
207+ functionConfig ,
208+ ) ,
196209 ) ;
197210 }
198211 }
@@ -248,20 +261,21 @@ export default function PlaygroundPage({ loaderData }: Route.ComponentProps) {
248261 offset,
249262 limit,
250263 } = loaderData ;
264+ const functionConfig = useFunctionConfig ( functionName ) ;
265+ if ( functionName && ! functionConfig ) {
266+ throw data ( `Function config not found for function ${ functionName } ` , {
267+ status : 404 ,
268+ } ) ;
269+ }
251270 const { map, setPromise } = useClientInferences (
252271 functionName ,
253272 datapoints ,
254273 inputs ,
255274 selectedVariants ,
256275 serverInferences ,
276+ functionConfig ,
257277 ) ;
258278
259- const functionConfig = useFunctionConfig ( functionName ) ;
260- if ( functionName && ! functionConfig ) {
261- throw data ( `Function config not found for function ${ functionName } ` , {
262- status : 404 ,
263- } ) ;
264- }
265279 const variants = functionConfig ?. variants ?? undefined ;
266280 const variantData = variants
267281 ? Object . entries ( variants ) . map ( ( [ variantName ] ) => ( {
@@ -334,7 +348,8 @@ export default function PlaygroundPage({ loaderData }: Route.ComponentProps) {
334348 datapoints . length > 0 &&
335349 datasetName &&
336350 inputs &&
337- functionName && (
351+ functionName &&
352+ functionConfig && (
338353 < >
339354 < div className = "overflow-x-auto rounded border" >
340355 < div className = "min-w-fit" >
@@ -427,6 +442,7 @@ export default function PlaygroundPage({ loaderData }: Route.ComponentProps) {
427442 setPromise = { setPromise }
428443 input = { inputs [ index ] }
429444 functionName = { functionName }
445+ functionConfig = { functionConfig }
430446 />
431447 </ div >
432448 ) ;
@@ -483,13 +499,14 @@ function useClientInferences(
483499 inputs : DisplayInput [ ] | undefined ,
484500 selectedVariants : string [ ] ,
485501 serverInferences : NestedPromiseMap < InferenceResponse > ,
502+ functionConfig : FunctionConfig | null ,
486503) {
487504 const { map, setPromise, setMap } =
488505 useNestedPromiseMap < InferenceResponse > ( serverInferences ) ;
489506
490507 // Single combined effect to handle both server inferences and client inferences
491508 useEffect ( ( ) => {
492- if ( ! functionName || ! datapoints || ! inputs ) return ;
509+ if ( ! functionName || ! datapoints || ! inputs || ! functionConfig ) return ;
493510
494511 // First check if we need any updates
495512 let needsUpdate = false ;
@@ -529,24 +546,28 @@ function useClientInferences(
529546 newMap . set ( variant , variantMap ) ;
530547 }
531548
532- const request = prepareInferenceActionRequest ( {
533- source : "clickhouse_datapoint" ,
534- input,
535- functionName,
536- variant : variant ,
537- tool_params :
538- datapoint ?. type === "chat"
539- ? ( datapoint . tool_params ?? undefined )
540- : undefined ,
541- output_schema :
542- datapoint ?. type === "json" ? datapoint . output_schema : null ,
543- // The default is write_only but we do off in the playground
544- cache_options : {
545- max_age_s : null ,
546- enabled : "off" ,
547- } ,
548- dryrun : true ,
549- } ) ;
549+ const request = {
550+ ...prepareInferenceActionRequest ( {
551+ source : "clickhouse_datapoint" ,
552+ input,
553+ functionName,
554+ variant : variant ,
555+ tool_params :
556+ datapoint ?. type === "chat"
557+ ? ( datapoint . tool_params ?? undefined )
558+ : undefined ,
559+ output_schema :
560+ datapoint ?. type === "json" ? datapoint . output_schema : null ,
561+ // The default is write_only, but we turn caching off in the playground
562+ cache_options : {
563+ max_age_s : null ,
564+ enabled : "off" ,
565+ } ,
566+ dryrun : true ,
567+ functionConfig,
568+ } ) ,
569+ ...getExtraInferenceOptions ( ) ,
570+ } ;
550571 const formData = new FormData ( ) ;
551572 formData . append ( "data" , JSON . stringify ( request ) ) ;
552573 const responsePromise = fetch ( "/api/tensorzero/inference" , {
@@ -564,7 +585,14 @@ function useClientInferences(
564585
565586 return newMap ;
566587 } ) ;
567- } , [ functionName , datapoints , inputs , selectedVariants , setMap ] ) ;
588+ } , [
589+ functionName ,
590+ datapoints ,
591+ inputs ,
592+ selectedVariants ,
593+ setMap ,
594+ functionConfig ,
595+ ] ) ;
568596
569597 return { map, setPromise, setMap } ;
570598}
0 commit comments