Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions tests/v1/test_serve.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import asyncio
from collections import Counter
from unittest.mock import AsyncMock, Mock

import pytest
import renderers
import renderers.client

import verifiers.v1.serve.server as serve_server
from verifiers.v1.clients import EvalClientConfig, TrainClientConfig
from verifiers.v1.clients.train import TrainClient
from verifiers.v1.dialects import ChatDialect
from verifiers.v1.serve.server import EnvServer
from verifiers.v1.types import SamplingConfig


def test_env_server_client_cache_keys(monkeypatch):
    """Check which (config, model) pairs share a cached client on `EnvServer`.

    A train config with a pinned ``renderer_model_name`` must hand every
    adapter the same client, while an unpinned train config and an eval config
    must each yield distinct clients for distinct model names.
    """
    # Every resolve returns a brand-new object, so identity exposes cache hits.
    fake_resolve = Mock(side_effect=lambda _: object())
    monkeypatch.setattr(serve_server, "resolve_client", fake_resolve)
    # Bypass __init__ and install only the cache dict that `_client` consults.
    env_server = object.__new__(EnvServer)
    env_server._clients = {}

    # Pinned renderer model: all eight adapters resolve to one client.
    pinned_config = TrainClientConfig(renderer_model_name="base-model")
    shared = [
        env_server._client(pinned_config, f"adapter-{index}") for index in range(8)
    ]
    first = shared[0]
    assert all(other is first for other in shared)

    # Unpinned train config: the adapter name is part of the cache key.
    env_server._clients.clear()
    unpinned_config = TrainClientConfig()
    unpinned_first = env_server._client(unpinned_config, "adapter-0")
    unpinned_second = env_server._client(unpinned_config, "adapter-1")
    assert unpinned_first is not unpinned_second

    # Eval config: likewise one client per model name.
    env_server._clients.clear()
    eval_client_config = EvalClientConfig()
    eval_first = env_server._client(eval_client_config, "model-0")
    eval_second = env_server._client(eval_client_config, "model-1")
    assert eval_first is not eval_second


async def test_pinned_train_client_routes_512_requests_through_one_pool(monkeypatch):
    """Send 512 requests across 8 adapters through one pinned train client.

    Verifies that the renderer pool is created exactly once for the pinned
    base model, that each adapter name is forwarded to ``generate`` 64 times,
    and that the server cache holds a single client to close.
    """
    # Bypass __init__ and install only the client cache.
    env_server = object.__new__(EnvServer)
    env_server._clients = {}
    train_config = TrainClientConfig(
        base_url="http://127.0.0.1:1", renderer_model_name="base-model"
    )
    adapter_names = [f"adapter-{index}" for index in range(8)]
    rollout_contexts = []
    for adapter_name in adapter_names:
        for _ in range(64):
            rollout_contexts.append(
                env_server._context(train_config, adapter_name, SamplingConfig())
            )
    shared_client = rollout_contexts[0].client
    assert isinstance(shared_client, TrainClient)

    # Stand-in pool object; only the factory's call record matters here.
    fake_pool = object()
    create_pool = Mock(return_value=fake_pool)
    monkeypatch.setattr(renderers, "create_renderer_pool", create_pool)

    canned_generation = {
        "request_id": "response",
        "content": "ok",
        "finish_reason": "stop",
        "prompt_ids": [1] * 160,
        "completion_ids": [2],
        "completion_logprobs": [-0.1],
    }
    generate_mock = AsyncMock(return_value=canned_generation)
    monkeypatch.setattr(renderers.client, "generate", generate_mock)

    def send(ctx):
        # Fresh dialect and request body per call, as each request is independent.
        return ctx.client.get_response(
            ChatDialect(),
            {"messages": [{"role": "user", "content": "hello"}]},
            ctx.model,
            ctx.sampling,
        )

    # Issue the requests in sequential batches of 128 concurrent calls.
    batch_size = 128
    responses = []
    for offset in range(0, len(rollout_contexts), batch_size):
        batch = rollout_contexts[offset : offset + batch_size]
        responses += await asyncio.gather(*[send(ctx) for ctx in batch])

    pool_call = create_pool.call_args
    assert pool_call.args == ("base-model", None)
    assert pool_call.kwargs == {"size": 1}
    assert create_pool.call_count == 1
    requests_per_model = Counter(
        call.kwargs["model"] for call in generate_mock.call_args_list
    )
    assert requests_per_model == dict.fromkeys(adapter_names, 64)
    # 82_432 == 512 * (160 + 1); presumably prompt plus completion tokens per
    # response -- confirm against how TrainClient fills `usage`.
    total_tokens = sum(response.usage.total_tokens for response in responses)
    assert total_tokens == 82_432

    # Closing every cached client must hit the shared client exactly once.
    close_mock = AsyncMock()
    monkeypatch.setattr(shared_client, "close", close_mock)
    for cached_client in env_server._clients.values():
        await cached_client.close()
    close_mock.assert_awaited_once()


async def test_renderer_pool_initialization_failure_is_cached(monkeypatch):
    """Check that a failed renderer-pool creation is attempted only once.

    Thirty-two concurrent ``_renderer_pool`` calls and one later call must all
    surface the same "renderer failed" error while the pool factory is invoked
    a single time.
    """
    startup_error = RuntimeError("renderer failed")
    create_pool = Mock(side_effect=startup_error)
    monkeypatch.setattr(renderers, "create_renderer_pool", create_pool)
    train_client = TrainClient(AsyncMock(), renderer_model_name="base-model")

    # Collect exceptions instead of raising so every waiter can be inspected.
    pending = [train_client._renderer_pool(f"adapter-{index}") for index in range(32)]
    outcomes = await asyncio.gather(*pending, return_exceptions=True)

    assert create_pool.call_count == 1
    messages = [str(outcome) for outcome in outcomes]
    assert all(message == "renderer failed" for message in messages)

    # A later call re-raises the stored failure without retrying creation.
    with pytest.raises(RuntimeError, match="renderer failed"):
        await train_client._renderer_pool("another-adapter")
    assert create_pool.call_count == 1
24 changes: 16 additions & 8 deletions verifiers/v1/clients/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
needs a running vLLM engine.
"""

import asyncio
import json
from collections.abc import Mapping
from typing import Any

from openai import AsyncOpenAI, OpenAIError
from renderers import RenderedTokens
from renderers import OverlongPromptError as RendererOverlongPromptError
from renderers import RendererConfig
from renderers import RendererConfig, RendererPool

from verifiers.v1.clients.client import SESSION_ID_HEADER, Client
from verifiers.v1.dialects import FINISH_REASONS, ChatDialect, Dialect, parse_tools
Expand Down Expand Up @@ -190,16 +191,23 @@ def __init__(
self.pool_size = pool_size
self.config = config
self.renderer_model_name = renderer_model_name
self._pool = None
self._pool_task: asyncio.Task[RendererPool] | None = None

def _renderer_pool(self, model: str):
if self._pool is None:
async def _renderer_pool(self, model: str) -> RendererPool:
if self._pool_task is None:
from renderers import create_renderer_pool

self._pool = create_renderer_pool(
self.renderer_model_name or model, self.config, size=self.pool_size
# Store one off-loop task before yielding so concurrent first calls initialize once.
self._pool_task = asyncio.create_task(
asyncio.to_thread(
Comment on lines +201 to +202

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Pin renderers before offloading pool creation

This moves `create_renderer_pool` into a worker thread, but the package metadata still allows `renderers>=0.1.8.dev40` and the lockfile still resolves `0.1.8.dev43`; the required upstream fix (renderers#91) is still open. New evidence since the prior comment: this commit did not bump or lock the dependency, so installs resolving the current range can still run the old process-wide fastokens/Transformers patch concurrently with environment or harness tokenizer loads during the first train request.

Useful? React with 👍 / 👎.

create_renderer_pool,
self.renderer_model_name or model,
self.config,
size=self.pool_size,
)
Comment thread
cursor[bot] marked this conversation as resolved.
)
return self._pool
# Shield waiter cancellation; the task caches either the pool or its startup failure.
return await asyncio.shield(self._pool_task)

async def get_response(
self,
Expand Down Expand Up @@ -230,7 +238,7 @@ async def get_response(
tools = parse_tools(body.get("tools"))
else:
prompt, tools = dialect.parse_request(body)
renderer = self._renderer_pool(model)
renderer = await self._renderer_pool(model)
from renderers.client import _maybe_offload, generate

wire_tools = [tool_to_wire(t) for t in tools] if tools else None
Expand Down
15 changes: 9 additions & 6 deletions verifiers/v1/serve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from verifiers.utils.serve_utils import msgpack_encoder
from verifiers.v1.clients import RolloutContext, resolve_client
from verifiers.v1.clients.client import Client
from verifiers.v1.clients.config import ClientConfig
from verifiers.v1.clients.config import ClientConfig, TrainClientConfig
from verifiers.v1.decorators import discover_decorated
from verifiers.v1.env import EnvConfig, Environment
from verifiers.v1.serve.types import (
Expand Down Expand Up @@ -61,7 +61,7 @@ def __init__(
)
self._clients: dict[
tuple[str, str], Client
] = {} # (client_config, model) -> Client
] = {} # (client_config, tokenizer/model) -> Client

self.ctx = zmq.asyncio.Context()
self.frontend = self.ctx.socket(zmq.ROUTER)
Expand Down Expand Up @@ -96,10 +96,13 @@ def run_server(cls, address_queue=None, **kwargs) -> None:
pass

def _client(self, client_config: ClientConfig, model: str) -> Client:
"""Resolve (and cache) a `Client` for this config+model. Cached because a
renderer client builds the model's tokenizer pool on first use — doing that
per request would be ruinous."""
key = (client_config.model_dump_json(), model)
"""Resolve and cache a client, sharing explicitly pinned renderer models."""
cache_model = (
client_config.renderer_model_name
if isinstance(client_config, TrainClientConfig)
else None
) or model
key = (client_config.model_dump_json(), cache_model)
if key not in self._clients:
self._clients[key] = resolve_client(client_config)
return self._clients[key]
Expand Down
Loading