From 86ff84664f9c5d92e0384619703c672ab7aa7ba7 Mon Sep 17 00:00:00 2001 From: Xeophon <46377542+xeophon@users.noreply.github.com> Date: Tue, 23 Jun 2026 06:16:22 +0200 Subject: [PATCH] Initialize and reuse V1 renderer pools efficiently --- tests/v1/test_serve.py | 116 ++++++++++++++++++++++++++++++++++ verifiers/v1/clients/train.py | 24 ++++--- verifiers/v1/serve/server.py | 15 +++-- 3 files changed, 141 insertions(+), 14 deletions(-) create mode 100644 tests/v1/test_serve.py diff --git a/tests/v1/test_serve.py b/tests/v1/test_serve.py new file mode 100644 index 000000000..de1e788b8 --- /dev/null +++ b/tests/v1/test_serve.py @@ -0,0 +1,116 @@ +import asyncio +from collections import Counter +from unittest.mock import AsyncMock, Mock + +import pytest +import renderers +import renderers.client + +import verifiers.v1.serve.server as serve_server +from verifiers.v1.clients import EvalClientConfig, TrainClientConfig +from verifiers.v1.clients.train import TrainClient +from verifiers.v1.dialects import ChatDialect +from verifiers.v1.serve.server import EnvServer +from verifiers.v1.types import SamplingConfig + + +def test_env_server_client_cache_keys(monkeypatch): + resolve = Mock(side_effect=lambda _: object()) + monkeypatch.setattr(serve_server, "resolve_client", resolve) + server = object.__new__(EnvServer) + server._clients = {} + + pinned = TrainClientConfig(renderer_model_name="base-model") + pinned_clients = [server._client(pinned, f"adapter-{i}") for i in range(8)] + assert len({id(client) for client in pinned_clients}) == 1 + + server._clients.clear() + unpinned = TrainClientConfig() + assert server._client(unpinned, "adapter-0") is not server._client( + unpinned, "adapter-1" + ) + + server._clients.clear() + eval_config = EvalClientConfig() + assert server._client(eval_config, "model-0") is not server._client( + eval_config, "model-1" + ) + + +async def test_pinned_train_client_routes_512_requests_through_one_pool(monkeypatch): + server = object.__new__(EnvServer) + server._clients = {} + config = TrainClientConfig( + base_url="http://127.0.0.1:1", renderer_model_name="base-model" + ) + adapters = [f"adapter-{i}" for i in range(8)] + contexts = [ + server._context(config, adapter, SamplingConfig()) + for adapter in adapters + for _ in range(64) + ] + shared_client = contexts[0].client + assert isinstance(shared_client, TrainClient) + + renderer = object() + create_pool = Mock(return_value=renderer) + monkeypatch.setattr(renderers, "create_renderer_pool", create_pool) + + generate_mock = AsyncMock( + return_value={ + "request_id": "response", + "content": "ok", + "finish_reason": "stop", + "prompt_ids": [1] * 160, + "completion_ids": [2], + "completion_logprobs": [-0.1], + } + ) + monkeypatch.setattr(renderers.client, "generate", generate_mock) + responses = [] + for start in range(0, len(contexts), 128): + responses.extend( + await asyncio.gather( + *( + ctx.client.get_response( + ChatDialect(), + {"messages": [{"role": "user", "content": "hello"}]}, + ctx.model, + ctx.sampling, + ) + for ctx in contexts[start : start + 128] + ) + ) + ) + + assert create_pool.call_args.args == ("base-model", None) + assert create_pool.call_args.kwargs == {"size": 1} + assert create_pool.call_count == 1 + assert Counter(call.kwargs["model"] for call in generate_mock.call_args_list) == { + adapter: 64 for adapter in adapters + } + assert sum(response.usage.total_tokens for response in responses) == 82_432 + + close = AsyncMock() + monkeypatch.setattr(shared_client, "close", close) + for client in server._clients.values(): + await client.close() + close.assert_awaited_once() + + +async def test_renderer_pool_initialization_failure_is_cached(monkeypatch): + failure = RuntimeError("renderer failed") + create_pool = Mock(side_effect=failure) + monkeypatch.setattr(renderers, "create_renderer_pool", create_pool) + client = TrainClient(AsyncMock(), renderer_model_name="base-model") + + results = await asyncio.gather( + *(client._renderer_pool(f"adapter-{i}") for i in range(32)), + return_exceptions=True, + ) + + assert create_pool.call_count == 1 + assert all(str(result) == "renderer failed" for result in results) + with pytest.raises(RuntimeError, match="renderer failed"): + await client._renderer_pool("another-adapter") + assert create_pool.call_count == 1 diff --git a/verifiers/v1/clients/train.py b/verifiers/v1/clients/train.py index 918302379..26eb459a5 100644 --- a/verifiers/v1/clients/train.py +++ b/verifiers/v1/clients/train.py @@ -8,6 +8,7 @@ needs a running vLLM engine. """ +import asyncio import json from collections.abc import Mapping from typing import Any @@ -15,7 +16,7 @@ from openai import AsyncOpenAI, OpenAIError from renderers import RenderedTokens from renderers import OverlongPromptError as RendererOverlongPromptError -from renderers import RendererConfig +from renderers import RendererConfig, RendererPool from verifiers.v1.clients.client import SESSION_ID_HEADER, Client from verifiers.v1.dialects import FINISH_REASONS, ChatDialect, Dialect, parse_tools @@ -190,16 +191,23 @@ def __init__( self.pool_size = pool_size self.config = config self.renderer_model_name = renderer_model_name - self._pool = None + self._pool_task: asyncio.Task[RendererPool] | None = None - def _renderer_pool(self, model: str): - if self._pool is None: + async def _renderer_pool(self, model: str) -> RendererPool: + if self._pool_task is None: from renderers import create_renderer_pool - self._pool = create_renderer_pool( - self.renderer_model_name or model, self.config, size=self.pool_size + # Store one off-loop task before yielding so concurrent first calls initialize once. + self._pool_task = asyncio.create_task( + asyncio.to_thread( + create_renderer_pool, + self.renderer_model_name or model, + self.config, + size=self.pool_size, + ) ) - return self._pool + # Shield waiter cancellation; the task caches either the pool or its startup failure. + return await asyncio.shield(self._pool_task) async def get_response( self, @@ -230,7 +238,7 @@ async def get_response( tools = parse_tools(body.get("tools")) else: prompt, tools = dialect.parse_request(body) - renderer = self._renderer_pool(model) + renderer = await self._renderer_pool(model) from renderers.client import _maybe_offload, generate wire_tools = [tool_to_wire(t) for t in tools] if tools else None diff --git a/verifiers/v1/serve/server.py b/verifiers/v1/serve/server.py index eb6b24af8..38889e4f2 100644 --- a/verifiers/v1/serve/server.py +++ b/verifiers/v1/serve/server.py @@ -28,7 +28,7 @@ from verifiers.utils.serve_utils import msgpack_encoder from verifiers.v1.clients import RolloutContext, resolve_client from verifiers.v1.clients.client import Client -from verifiers.v1.clients.config import ClientConfig +from verifiers.v1.clients.config import ClientConfig, TrainClientConfig from verifiers.v1.decorators import discover_decorated from verifiers.v1.env import EnvConfig, Environment from verifiers.v1.serve.types import ( @@ -61,7 +61,7 @@ def __init__( ) self._clients: dict[ tuple[str, str], Client - ] = {} # (client_config, model) -> Client + ] = {} # (client_config, tokenizer/model) -> Client self.ctx = zmq.asyncio.Context() self.frontend = self.ctx.socket(zmq.ROUTER) @@ -96,10 +96,13 @@ def run_server(cls, address_queue=None, **kwargs) -> None: pass def _client(self, client_config: ClientConfig, model: str) -> Client: - """Resolve (and cache) a `Client` for this config+model. Cached because a - renderer client builds the model's tokenizer pool on first use — doing that - per request would be ruinous.""" - key = (client_config.model_dump_json(), model) + """Resolve and cache a client, sharing explicitly pinned renderer models.""" + cache_model = ( + client_config.renderer_model_name + if isinstance(client_config, TrainClientConfig) + else None + ) or model + key = (client_config.model_dump_json(), cache_model) if key not in self._clients: self._clients[key] = resolve_client(client_config) return self._clients[key]