|
1 | 1 | import pytest |
2 | 2 | from inline_snapshot import snapshot |
3 | 3 | from pydantic_ai import ToolOutput |
4 | | -from pydantic_ai.models.openai import OpenAIChatModel |
| 4 | +from pydantic_ai.models.openai import OpenAIChatModel, OpenAIChatModelSettings |
| 5 | +from pydantic_ai.models.test import TestModel |
5 | 6 | from pydantic_ai.providers.openai import OpenAIProvider |
6 | 7 |
|
7 | 8 | from cragents import Anchor, Constrain, CRAgent, Free, Think, UseTools, vllm_model_profile |
|
24 | 25 | ] |
25 | 26 |
|
26 | 27 |
|
| 28 | +# ── end-to-end set_guide output type tests ──────────────────────────────────── |
| 29 | + |
| 30 | + |
27 | 31 | async def test_default_agent_output(): |
28 | 32 | agent = CRAgent(model) |
29 | 33 | await agent.set_guide(generation_sequence) |
@@ -134,3 +138,102 @@ async def test_mixed_output_type(): |
134 | 138 | } |
135 | 139 | } |
136 | 140 | ) |
| 141 | + |
| 142 | + |
| 143 | +# ── set_guide error handling and model settings ─────────────────────────────── |
| 144 | + |
| 145 | + |
| 146 | +async def test_set_guide_requires_openai_model(): |
| 147 | + agent = CRAgent(TestModel()) |
| 148 | + with pytest.raises(RuntimeError, match="OpenAIChatModel required"): |
| 149 | + await agent.set_guide([Anchor("hi ")]) |
| 150 | + |
| 151 | + |
| 152 | +async def test_set_guide_creates_model_settings_when_none(): |
| 153 | + agent = CRAgent(model) |
| 154 | + assert agent.model_settings is None |
| 155 | + await agent.set_guide([Anchor("hi ")]) |
| 156 | + assert agent.model_settings is not None |
| 157 | + assert "extra_body" in agent.model_settings |
| 158 | + |
| 159 | + |
| 160 | +async def test_set_guide_preserves_existing_model_settings(): |
| 161 | + agent = CRAgent(model, model_settings=OpenAIChatModelSettings(temperature=0.5)) |
| 162 | + await agent.set_guide([Anchor("hi ")]) |
| 163 | + assert agent.model_settings["temperature"] == 0.5 |
| 164 | + assert "extra_body" in agent.model_settings |
| 165 | + |
| 166 | + |
| 167 | +async def test_set_guide_overwrites_on_second_call(): |
| 168 | + agent = CRAgent(model) |
| 169 | + await agent.set_guide([Anchor("first ")]) |
| 170 | + first_grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"] |
| 171 | + await agent.set_guide([Anchor("second ")]) |
| 172 | + second_grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"] |
| 173 | + assert first_grammar != second_grammar |
| 174 | + assert "second" in second_grammar |
| 175 | + |
| 176 | + |
| 177 | +# ── set_guide UseTools schema handling ──────────────────────────────────────── |
| 178 | + |
| 179 | + |
| 180 | +async def test_set_guide_explicit_use_tools_schema_not_overwritten(): |
| 181 | + explicit_schema = {"type": "number"} |
| 182 | + agent = CRAgent(model) |
| 183 | + await agent.set_guide([UseTools(json_schema=explicit_schema)]) |
| 184 | + grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"] |
| 185 | + assert '"type": "number"' in grammar |
| 186 | + |
| 187 | + |
| 188 | +async def test_set_guide_use_tools_with_registered_tool(): |
| 189 | + agent = CRAgent(model) |
| 190 | + |
| 191 | + @agent.tool_plain |
| 192 | + def my_tool(x: int) -> str: |
| 193 | + return str(x) |
| 194 | + |
| 195 | + await agent.set_guide([UseTools()]) |
| 196 | + grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"] |
| 197 | + # The tool's parameter schema (containing "x") should appear in the grammar |
| 198 | + assert '"x"' in grammar |
| 199 | + |
| 200 | + |
| 201 | +async def test_set_guide_use_tools_tool_names(): |
| 202 | + agent = CRAgent(model) |
| 203 | + await agent.set_guide([UseTools(json_schema={"type": "string"}, tool_names=["alpha", "beta"])]) |
| 204 | + grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"] |
| 205 | + assert 'FUNCTION_NAME: ("alpha" | "beta")' in grammar |
| 206 | + |
| 207 | + |
| 208 | +async def test_set_guide_merges_toolset_with_anyof_output(): |
| 209 | + # When the output schema already has anyOf (multiple output types) and the agent |
| 210 | + # also has registered tools, anyOf = toolset_schemas + return_schema["anyOf"] |
| 211 | + agent = CRAgent(model, output_type=[ToolOutput(bool), ToolOutput(int)]) |
| 212 | + |
| 213 | + @agent.tool_plain |
| 214 | + def helper(x: str) -> str: |
| 215 | + return x |
| 216 | + |
| 217 | + await agent.set_guide([UseTools()]) |
| 218 | + grammar = agent.model_settings["extra_body"]["structured_outputs"]["grammar"] |
| 219 | + assert "tool_schema" in grammar |
| 220 | + assert "anyOf" in grammar |
| 221 | + |
| 222 | + |
| 223 | +# ── vllm_model_profile ───────────────────────────────────────────────────────── |
| 224 | + |
| 225 | + |
| 226 | +def test_vllm_profile_strict_tool_definition(): |
| 227 | + assert vllm_model_profile.openai_supports_strict_tool_definition is False |
| 228 | + |
| 229 | + |
| 230 | +def test_vllm_profile_tool_choice_required(): |
| 231 | + assert vllm_model_profile.openai_supports_tool_choice_required is False |
| 232 | + |
| 233 | + |
| 234 | +def test_vllm_profile_json_object_output(): |
| 235 | + assert vllm_model_profile.supports_json_object_output is False |
| 236 | + |
| 237 | + |
| 238 | +def test_vllm_profile_json_schema_output(): |
| 239 | + assert vllm_model_profile.supports_json_schema_output is True |
0 commit comments