Skip to content

Commit 1b3cbad

Browse files
authored
add an OpenAI compatible embeddings endpoint (tensorzero#3022)
* added batching to existing embeddings code * implemented internal embeddings handler * wip * openai embeddings handler is implemented * embeddings implementation works * wip * added initial rust embeddings e2e test * added many tests for embeddings * added implentation of provider timeouts * added implementation of model timeout * fixed node bindings * added tests for timeouts * forgot a file * added python test for base64 embedding * removed stray edit to toml * cleaned up PR comments * fixed test flag * fixed issue with dicl inserts in tests
1 parent 843af78 commit 1b3cbad

42 files changed

Lines changed: 1310 additions & 159 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# type: ignore
2+
"""
3+
Shared test fixtures for TensorZero OpenAI client tests
4+
"""
5+
6+
import os
7+
8+
import pytest_asyncio
9+
from openai import AsyncOpenAI
10+
11+
TEST_CONFIG_FILE = os.path.join(
12+
os.path.dirname(os.path.abspath(__file__)),
13+
"../../../tensorzero-core/tests/e2e/tensorzero.toml",
14+
)
15+
16+
17+
@pytest_asyncio.fixture
18+
async def async_client():
19+
async with AsyncOpenAI(
20+
api_key="donotuse", base_url="http://localhost:3000/openai/v1"
21+
) as client:
22+
yield client
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# type: ignore
2+
"""
3+
Tests for the TensorZero embeddings API using the OpenAI Python client
4+
5+
These tests cover the embeddings functionality of the TensorZero OpenAI-compatible interface.
6+
7+
To run:
8+
```
9+
pytest tests/test_embeddings.py
10+
```
11+
or
12+
```
13+
uv run pytest tests/test_embeddings.py
14+
```
15+
"""
16+
17+
import pytest
18+
19+
20+
@pytest.mark.asyncio
21+
async def test_basic_embeddings(async_client):
22+
"""Test basic embeddings generation with a single input"""
23+
result = await async_client.embeddings.create(
24+
input="Hello, world!",
25+
model="text-embedding-3-small",
26+
)
27+
28+
# Verify the response structure
29+
assert result.model == "text-embedding-3-small"
30+
assert len(result.data) == 1
31+
assert result.data[0].index == 0
32+
assert result.data[0].object == "embedding"
33+
assert len(result.data[0].embedding) > 0 # Should have embedding vector
34+
assert result.usage.prompt_tokens > 0
35+
assert result.usage.total_tokens > 0
36+
37+
38+
@pytest.mark.asyncio
39+
async def test_basic_embeddings_shorthand(async_client):
40+
"""Test basic embeddings generation with a single input"""
41+
result = await async_client.embeddings.create(
42+
input="Hello, world!",
43+
model="openai::text-embedding-3-large",
44+
)
45+
46+
# Verify the response structure
47+
assert result.model == "openai::text-embedding-3-large"
48+
assert len(result.data) == 1
49+
assert result.data[0].index == 0
50+
assert result.data[0].object == "embedding"
51+
assert len(result.data[0].embedding) > 0 # Should have embedding vector
52+
assert result.usage.prompt_tokens > 0
53+
assert result.usage.total_tokens > 0
54+
55+
56+
@pytest.mark.asyncio
57+
async def test_batch_embeddings(async_client):
58+
"""Test embeddings generation with multiple inputs"""
59+
inputs = [
60+
"Hello, world!",
61+
"How are you today?",
62+
"This is a test of batch embeddings.",
63+
]
64+
65+
result = await async_client.embeddings.create(
66+
input=inputs,
67+
model="text-embedding-3-small",
68+
)
69+
70+
# Verify the response structure
71+
assert result.model == "text-embedding-3-small"
72+
assert len(result.data) == len(inputs)
73+
74+
for i, embedding_data in enumerate(result.data):
75+
assert embedding_data.index == i
76+
assert embedding_data.object == "embedding"
77+
assert len(embedding_data.embedding) > 0
78+
79+
assert result.usage.prompt_tokens > 0
80+
assert result.usage.total_tokens > 0
81+
82+
83+
@pytest.mark.asyncio
84+
async def test_embeddings_with_dimensions(async_client):
85+
"""Test embeddings with specified dimensions"""
86+
result = await async_client.embeddings.create(
87+
input="Test with specific dimensions",
88+
model="text-embedding-3-small",
89+
dimensions=512,
90+
)
91+
92+
# Verify the response structure
93+
assert result.model == "text-embedding-3-small"
94+
assert len(result.data) == 1
95+
# Should match requested dimensions
96+
assert len(result.data[0].embedding) == 512
97+
98+
99+
@pytest.mark.asyncio
100+
async def test_embeddings_with_encoding_format_float(async_client):
101+
"""Test embeddings with different encoding formats"""
102+
result = await async_client.embeddings.create(
103+
input="Test encoding format",
104+
model="text-embedding-3-small",
105+
encoding_format="float",
106+
)
107+
108+
# Verify the response structure
109+
assert result.model == "text-embedding-3-small"
110+
assert len(result.data) == 1
111+
assert isinstance(result.data[0].embedding[0], float)
112+
113+
114+
@pytest.mark.asyncio
115+
async def test_embeddings_with_encoding_format_base64(async_client):
116+
"""Test embeddings with different encoding formats"""
117+
result = await async_client.embeddings.create(
118+
input="Test encoding format",
119+
model="text-embedding-3-small",
120+
encoding_format="base64",
121+
)
122+
123+
# Verify the response structure
124+
assert result.model == "text-embedding-3-small"
125+
assert len(result.data) == 1
126+
assert isinstance(result.data[0].embedding, str)
127+
128+
129+
@pytest.mark.asyncio
130+
async def test_embeddings_with_user_parameter(async_client):
131+
"""Test embeddings with user parameter for tracking"""
132+
user_id = "test_user_123"
133+
result = await async_client.embeddings.create(
134+
input="Test with user parameter",
135+
model="text-embedding-3-small",
136+
user=user_id,
137+
)
138+
139+
# Verify the response structure
140+
assert result.model == "text-embedding-3-small"
141+
assert len(result.data) == 1
142+
assert len(result.data[0].embedding) > 0
143+
144+
145+
@pytest.mark.asyncio
146+
async def test_embeddings_invalid_model_error(async_client):
147+
"""Test that invalid model name raises appropriate error"""
148+
with pytest.raises(Exception) as exc_info:
149+
await async_client.embeddings.create(
150+
input="Test invalid model",
151+
model="tensorzero::model_name::nonexistent_model",
152+
)
153+
154+
# Should get a 404 error for unknown model
155+
assert exc_info.value.status_code == 404
156+
157+
158+
@pytest.mark.asyncio
159+
async def test_embeddings_large_batch(async_client):
160+
"""Test embeddings with a larger batch of inputs"""
161+
# Create a batch of 10 different inputs
162+
inputs = [f"This is test input number {i + 1}" for i in range(10)]
163+
164+
result = await async_client.embeddings.create(
165+
input=inputs,
166+
model="text-embedding-3-small",
167+
)
168+
169+
# Verify the response structure
170+
assert result.model == "text-embedding-3-small"
171+
assert len(result.data) == 10
172+
173+
# Verify each embedding
174+
for i, embedding_data in enumerate(result.data):
175+
assert embedding_data.index == i
176+
assert embedding_data.object == "embedding"
177+
assert len(embedding_data.embedding) > 0
178+
179+
assert result.usage.prompt_tokens > 0
180+
assert result.usage.total_tokens > 0
181+
182+
183+
@pytest.mark.asyncio
184+
async def test_embeddings_consistency(async_client):
185+
"""Test that the same input produces consistent embeddings"""
186+
input_text = "This is a consistency test"
187+
188+
# Generate embeddings twice with the same input
189+
result1 = await async_client.embeddings.create(
190+
input=input_text,
191+
model="text-embedding-3-small",
192+
)
193+
194+
result2 = await async_client.embeddings.create(
195+
input=input_text,
196+
model="text-embedding-3-small",
197+
)
198+
199+
# Both should have the same model and structure
200+
assert result1.model == result2.model
201+
assert len(result1.data) == len(result2.data) == 1
202+
assert len(result1.data[0].embedding) == len(result2.data[0].embedding)
203+
204+
# The embeddings should be identical for the same input
205+
# (assuming deterministic behavior or proper caching)
206+
embedding1 = result1.data[0].embedding
207+
embedding2 = result2.data[0].embedding
208+
209+
# Check that embeddings are similar (allowing for small numerical differences)
210+
for i in range(min(10, len(embedding1))): # Check first 10 dimensions
211+
assert abs(embedding1[i] - embedding2[i]) < 0.01, (
212+
f"Embeddings differ significantly at index {i}"
213+
)

clients/openai-python/tests/test_openai_compatibility.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727
from uuid import UUID
2828

2929
import pytest
30-
import pytest_asyncio
31-
from openai import AsyncOpenAI, BadRequestError
30+
from openai import BadRequestError
3231
from pydantic import BaseModel, ValidationError
3332
from uuid_utils.compat import uuid7
3433

@@ -38,14 +37,6 @@
3837
)
3938

4039

41-
@pytest_asyncio.fixture
42-
async def async_client():
43-
async with AsyncOpenAI(
44-
api_key="donotuse", base_url="http://localhost:3000/openai/v1"
45-
) as client:
46-
yield client
47-
48-
4940
@pytest.mark.asyncio
5041
async def test_async_basic_inference(async_client):
5142
messages = [

gateway/src/main.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ async fn main() {
228228
"/openai/v1/chat/completions",
229229
post(endpoints::openai_compatible::inference_handler),
230230
)
231+
.route(
232+
"/openai/v1/embeddings",
233+
post(endpoints::openai_compatible::embeddings_handler),
234+
)
231235
.route("/feedback", post(endpoints::feedback::feedback_handler))
232236
// Everything above this layer has OpenTelemetry tracing enabled
233237
// Note - we do *not* attach a `OtelInResponseLayer`, as this seems to be incorrect according to the W3C Trace Context spec
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2-
import type { EmbeddingProviderConfig } from "./EmbeddingProviderConfig";
2+
import type { EmbeddingProviderInfo } from "./EmbeddingProviderInfo";
3+
import type { TimeoutsConfig } from "./TimeoutsConfig";
34

45
export type EmbeddingModelConfig = {
56
routing: Array<string>;
6-
providers: { [key in string]?: EmbeddingProviderConfig };
7+
providers: { [key in string]?: EmbeddingProviderInfo };
8+
timeouts: TimeoutsConfig;
79
};
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// This file was generated by [ts-rs](https://github.com/Aleph-Alpha/ts-rs). Do not edit this file manually.
2+
import type { EmbeddingProviderConfig } from "./EmbeddingProviderConfig";
3+
import type { TimeoutsConfig } from "./TimeoutsConfig";
4+
5+
export type EmbeddingProviderInfo = {
6+
inner: EmbeddingProviderConfig;
7+
timeouts: TimeoutsConfig;
8+
provider_name: string;
9+
};

internal/tensorzero-node/lib/bindings/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ export * from "./DynamicJSONSchema";
3232
export * from "./DynamicToolConfig";
3333
export * from "./EmbeddingModelConfig";
3434
export * from "./EmbeddingProviderConfig";
35+
export * from "./EmbeddingProviderInfo";
3536
export * from "./EvaluationConfig";
3637
export * from "./EvaluatorConfig";
3738
export * from "./ExactMatchConfig";

tensorzero-core/src/cache.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::collections::HashMap;
22
use std::sync::Arc;
33

44
use crate::clickhouse::{ClickHouseConnectionInfo, TableName};
5-
use crate::embeddings::{EmbeddingRequest, EmbeddingResponse};
5+
use crate::embeddings::{EmbeddingModelResponse, EmbeddingRequest};
66
use crate::error::{warn_discarded_cache_write, Error, ErrorDetails};
77
use crate::inference::types::file::serialize_with_file_data;
88
use crate::inference::types::{
@@ -387,14 +387,14 @@ pub async fn embedding_cache_lookup(
387387
clickhouse_connection_info: &ClickHouseConnectionInfo,
388388
request: &EmbeddingModelProviderRequest<'_>,
389389
max_age_s: Option<u32>,
390-
) -> Result<Option<EmbeddingResponse>, Error> {
390+
) -> Result<Option<EmbeddingModelResponse>, Error> {
391391
let result = cache_lookup_inner::<EmbeddingCacheData>(
392392
clickhouse_connection_info,
393393
request.get_cache_key()?,
394394
max_age_s,
395395
)
396396
.await?;
397-
Ok(result.map(|result| EmbeddingResponse::from_cache(result, request)))
397+
Ok(result.map(|result| EmbeddingModelResponse::from_cache(result, request)))
398398
}
399399

400400
pub async fn cache_lookup(

tensorzero-core/src/config_parser/tests.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ async fn test_config_from_toml_table_valid() {
124124
assert_eq!(embedding_model.routing, vec!["openai".into()]);
125125
assert_eq!(embedding_model.providers.len(), 1);
126126
let provider = embedding_model.providers.get("openai").unwrap();
127-
assert!(matches!(provider, EmbeddingProviderConfig::OpenAI(_)));
127+
assert!(matches!(provider.inner, EmbeddingProviderConfig::OpenAI(_)));
128128

129129
// Check that the function for the LLM Judge evaluation is added to the functions table
130130
let function = config

0 commit comments

Comments
 (0)