Skip to content

Commit 7ca3dee

Browse files
authored
use native client for inference in UI (tensorzero#2975)
* added ClientInferenceParams and inferenceResponse to TS bindings * refactored getting native tensorzero client into its own file * refactored utilites for inference * fixed import issues * fixed bindings generation * implemented a rust-safe stringify helper * implemented type safe conversion to rust input type * more typing fixes * inference works via rust client for server inferences in playground * deprecated ts inference client * cleaned up PR * consolidated building bindings into cargo tsbuild * added a comment about caching settings in playground * inference should throw if `extra_body` is set * fixed check for extra body * fixed error handling for try with variant * fixed issue with merge
1 parent 12c9d60 commit 7ca3dee

91 files changed

Lines changed: 993 additions & 608 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.cargo/config.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,11 @@ build-e2e = "build --bin gateway --features e2e_tests"
4444
run-e2e = "run --bin gateway --features e2e_tests -- --config-file tensorzero-core/tests/e2e/tensorzero.toml"
4545
watch-e2e = "watch -x run-e2e"
4646

47-
# Export Typescript bindings for TensorZero
48-
tsbuild = ["test", "export_bindings", "-p", "tensorzero-core"]
47+
tsbuild = [
48+
"test",
49+
"export_bindings",
50+
"-p",
51+
"tensorzero-core",
52+
"-p",
53+
"tensorzero",
54+
] # Export Typescript bindings for TensorZero

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

clients/rust/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ reqwest-eventsource = { workspace = true }
2121
async-stream = { workspace = true }
2222
tokio-stream = { workspace = true }
2323
tensorzero-core = { path = "../../tensorzero-core" }
24+
ts-rs = { workspace = true }
2425
url = { workspace = true }
2526
thiserror = "2.0.11"
2627
pyo3 = { workspace = true, optional = true }

clients/rust/src/client_inference_params.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ use crate::client_input::{test_client_input_to_input, ClientInput};
2121
// This is a copy-paste of the `Params` struct from `tensorzero_core::endpoints::inference::Params`.
2222
// with just the `credentials` field adjusted to allow serialization.
2323
/// The expected payload is a JSON object with the following fields:
24-
#[derive(Clone, Debug, Serialize, Default)]
24+
#[derive(Clone, Debug, Deserialize, Serialize, Default, ts_rs::TS)]
25+
#[ts(export)]
2526
pub struct ClientInferenceParams {
2627
// The function name. Exactly one of `function_name` or `model_name` must be provided.
2728
pub function_name: Option<String>,
@@ -61,16 +62,27 @@ pub struct ClientInferenceParams {
6162
// If provided for a JSON inference, the inference will use the specified output schema instead of the
6263
// configured one. We only lazily validate this schema.
6364
pub output_schema: Option<Value>,
65+
#[ts(type = "Map<string, string>")]
6466
pub credentials: HashMap<String, ClientSecretString>,
6567
pub cache_options: CacheParamsOptions,
6668
/// If `true`, add an `original_response` field to the response, containing the raw string response from the model.
6769
/// Note that for complex variants (e.g. `experimental_best_of_n_sampling`), the response may not contain `original_response`
6870
/// if the fuser/judge model failed
6971
#[serde(default)]
7072
pub include_original_response: bool,
73+
// NOTE: Currently, ts_rs does not handle #[serde(transparent)] correctly,
74+
// so we disable the type generation for the extra_body and extra_headers fields.
75+
// I tried doing a direct #[ts(type = "InferenceExtraBody[]")] and
76+
// a #[ts(as = "Vec<InferenceExtraBody>")] and these would generate the types but then
77+
// type checking would fail because the ClientInferenceParams struct would not be
78+
// generated with the correct import.
79+
//
80+
// Not sure if this is solvable with the existing crate.
7181
#[serde(default)]
82+
#[ts(skip)]
7283
pub extra_body: UnfilteredInferenceExtraBody,
7384
#[serde(default)]
85+
#[ts(skip)]
7486
pub extra_headers: UnfilteredInferenceExtraHeaders,
7587
pub internal_dynamic_variant_config: Option<UninitializedVariantInfo>,
7688
}

clients/rust/src/client_input.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ use tensorzero_derive::TensorZeroDeserialize;
1111
// Like the normal `Input` type, but with `ClientInputMessage` instead of `InputMessage`.
1212
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Default)]
1313
#[serde(deny_unknown_fields)]
14+
#[derive(ts_rs::TS)]
15+
#[ts(export)]
1416
pub struct ClientInput {
1517
#[serde(skip_serializing_if = "Option::is_none")]
1618
pub system: Option<Value>,
@@ -21,6 +23,8 @@ pub struct ClientInput {
2123
// Like the normal `InputMessage` type, but with `ClientInputMessageContent` instead of `InputMessageContent`.
2224
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
2325
#[serde(deny_unknown_fields)]
26+
#[derive(ts_rs::TS)]
27+
#[ts(export)]
2428
pub struct ClientInputMessage {
2529
pub role: Role,
2630
#[serde(deserialize_with = "deserialize_content")]
@@ -30,6 +34,8 @@ pub struct ClientInputMessage {
3034
#[derive(Clone, Debug, TensorZeroDeserialize, Serialize, PartialEq)]
3135
#[serde(tag = "type")]
3236
#[serde(rename_all = "snake_case")]
37+
#[derive(ts_rs::TS)]
38+
#[ts(export)]
3339
pub enum ClientInputMessageContent {
3440
Text(TextKind),
3541
ToolCall(ToolCallInput),

internal/tensorzero-node/build-bindings.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ mkdir -p lib/bindings
1313
# Generate TypeScript bindings from Rust code
1414
echo "Generating TypeScript bindings from Rust..."
1515
cd ../..
16-
TS_RS_EXPORT_DIR="../internal/tensorzero-node/lib/bindings" cargo tsbuild
16+
TS_RS_EXPORT_DIR="$(pwd)/internal/tensorzero-node/lib/bindings" cargo tsbuild
1717
cd internal/tensorzero-node
1818

1919
# Generate index file

internal/tensorzero-node/lib/bindings/BestOfNEvaluatorConfig.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2-
import type { ExtraBodyConfig } from "./ExtraBodyConfig";
3-
import type { ExtraHeadersConfig } from "./ExtraHeadersConfig";
42
import type { JsonMode } from "./JsonMode";
53
import type { PathWithContents } from "./PathWithContents";
64
import type { RetryConfig } from "./RetryConfig";
@@ -20,6 +18,4 @@ export type BestOfNEvaluatorConfig = {
2018
stop_sequences: Array<string> | null;
2119
json_mode: JsonMode | null;
2220
retries: RetryConfig;
23-
extra_body: ExtraBodyConfig | null;
24-
extra_headers: ExtraHeadersConfig | null;
2521
};

internal/tensorzero-node/lib/bindings/ExtraHeader.ts renamed to internal/tensorzero-node/lib/bindings/CacheEnabledMode.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
22

3-
export type ExtraHeader = { name: string } & ({ value: string } | "delete");
3+
export type CacheEnabledMode = "on" | "off" | "read_only" | "write_only";
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2+
import type { CacheEnabledMode } from "./CacheEnabledMode";
3+
4+
export type CacheParamsOptions = {
5+
max_age_s: number | null;
6+
enabled: CacheEnabledMode;
7+
};

internal/tensorzero-node/lib/bindings/ChainOfThoughtConfig.ts

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2-
import type { ExtraBodyConfig } from "./ExtraBodyConfig";
3-
import type { ExtraHeadersConfig } from "./ExtraHeadersConfig";
42
import type { JsonMode } from "./JsonMode";
53
import type { PathWithContents } from "./PathWithContents";
64
import type { RetryConfig } from "./RetryConfig";
@@ -20,6 +18,4 @@ export type ChainOfThoughtConfig = {
2018
stop_sequences: Array<string> | null;
2119
json_mode: JsonMode | null;
2220
retries: RetryConfig;
23-
extra_body: ExtraBodyConfig | null;
24-
extra_headers: ExtraHeadersConfig | null;
2521
};

0 commit comments

Comments
 (0)