|
| 1 | +"""Test InjectedState with NotRequired state fields. |
| 2 | +
|
| 3 | +This tests the fix for https://github.com/langchain-ai/langchain/issues/35585 |
| 4 | +
|
| 5 | +When using InjectedState(<field>) on a tool parameter, and the referenced field is |
| 6 | +declared as NotRequired in the custom state schema, the ToolNode should gracefully |
| 7 | +handle missing fields by injecting None instead of raising KeyError. |
| 8 | +""" |
| 9 | + |
| 10 | +import sys |
| 11 | +from typing import Annotated |
| 12 | + |
| 13 | +import pytest |
| 14 | +from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage |
| 15 | +from langchain_core.tools import tool |
| 16 | +from langgraph.graph.message import add_messages |
| 17 | +from pydantic import BaseModel, Field |
| 18 | +from typing_extensions import NotRequired |
| 19 | + |
| 20 | +from langgraph.prebuilt import InjectedState, ToolNode, create_react_agent |
| 21 | +from langgraph.prebuilt.chat_agent_executor import AgentState |
| 22 | + |
| 23 | +from .model import FakeToolCallingModel |
| 24 | + |
| 25 | + |
| 26 | +class CustomAgentStateWithNotRequired(AgentState): |
| 27 | + """Custom state with a NotRequired field (TypedDict style).""" |
| 28 | + |
| 29 | + city: NotRequired[str] |
| 30 | + |
| 31 | + |
| 32 | +class CustomAgentStatePydanticWithDefault(BaseModel): |
| 33 | + """Custom state with Optional field and default (Pydantic style).""" |
| 34 | + |
| 35 | + messages: Annotated[list[AnyMessage], add_messages] |
| 36 | + remaining_steps: int = Field(default=10) |
| 37 | + city: str | None = Field(default=None) |
| 38 | + |
| 39 | + |
| 40 | +@tool |
| 41 | +def get_weather(city: Annotated[str | None, InjectedState("city")] = None) -> str: |
| 42 | + """Get weather for a given city.""" |
| 43 | + if city is None: |
| 44 | + return "No city provided" |
| 45 | + return f"It's always sunny in {city}!" |
| 46 | + |
| 47 | + |
| 48 | +def _create_mock_runtime( |
| 49 | + state: dict | None = None, |
| 50 | + store=None, |
| 51 | +): |
| 52 | + """Create a mock Runtime for testing ToolNode directly.""" |
| 53 | + from unittest.mock import Mock |
| 54 | + |
| 55 | + from langgraph.runtime import Runtime |
| 56 | + |
| 57 | + mock_runtime = Mock(spec=Runtime) |
| 58 | + mock_runtime.context = {} |
| 59 | + return mock_runtime |
| 60 | + |
| 61 | + |
| 62 | +def _create_config_with_runtime(store=None, state=None): |
| 63 | + """Create a RunnableConfig with mocked runtime for direct ToolNode testing.""" |
| 64 | + from langgraph.prebuilt.tool_node import ToolRuntime |
| 65 | + |
| 66 | + tool_runtime = ToolRuntime( |
| 67 | + state=state or {}, |
| 68 | + config={}, |
| 69 | + context={}, |
| 70 | + store=store, |
| 71 | + stream_writer=None, |
| 72 | + tool_call_id="test_id", |
| 73 | + ) |
| 74 | + return { |
| 75 | + "configurable": { |
| 76 | + "__pregel_runtime": _create_mock_runtime(), |
| 77 | + "__tool_runtime__": tool_runtime, |
| 78 | + } |
| 79 | + } |
| 80 | + |
| 81 | + |
| 82 | +@pytest.mark.skipif( |
| 83 | + sys.version_info < (3, 11), |
| 84 | + reason="InjectedState field extraction from Optional[Annotated[...]] not supported on Python <3.11", |
| 85 | +) |
| 86 | +def test_injected_state_not_required_field_missing_injects_none(): |
| 87 | + """Test that InjectedState with NotRequired field injects None when field is missing. |
| 88 | +
|
| 89 | + This verifies the fix for https://github.com/langchain-ai/langchain/issues/35585 |
| 90 | + """ |
| 91 | + tool_node = ToolNode([get_weather]) |
| 92 | + |
| 93 | + tool_call = { |
| 94 | + "name": "get_weather", |
| 95 | + "args": {}, |
| 96 | + "id": "call_1", |
| 97 | + "type": "tool_call", |
| 98 | + } |
| 99 | + ai_msg = AIMessage("Let me check the weather", tool_calls=[tool_call]) |
| 100 | + |
| 101 | + # State WITHOUT the "city" field - should inject None instead of raising KeyError |
| 102 | + state_without_city: CustomAgentStateWithNotRequired = { |
| 103 | + "messages": [HumanMessage("What's the weather?"), ai_msg], |
| 104 | + } |
| 105 | + |
| 106 | + result = tool_node.invoke( |
| 107 | + state_without_city, |
| 108 | + config=_create_config_with_runtime(state=state_without_city), |
| 109 | + ) |
| 110 | + |
| 111 | + assert len(result["messages"]) == 1 |
| 112 | + tool_msg = result["messages"][0] |
| 113 | + assert isinstance(tool_msg, ToolMessage) |
| 114 | + assert "No city provided" in tool_msg.content |
| 115 | + |
| 116 | + |
| 117 | +@pytest.mark.skipif( |
| 118 | + sys.version_info < (3, 11), |
| 119 | + reason="InjectedState field extraction from Optional[Annotated[...]] not supported on Python <3.11", |
| 120 | +) |
| 121 | +def test_injected_state_not_required_field_present_works(): |
| 122 | + """Test that InjectedState with NotRequired field works when field IS present.""" |
| 123 | + tool_node = ToolNode([get_weather]) |
| 124 | + |
| 125 | + tool_call = { |
| 126 | + "name": "get_weather", |
| 127 | + "args": {}, |
| 128 | + "id": "call_1", |
| 129 | + "type": "tool_call", |
| 130 | + } |
| 131 | + ai_msg = AIMessage("Let me check the weather", tool_calls=[tool_call]) |
| 132 | + |
| 133 | + # State WITH the "city" field - this should work |
| 134 | + state_with_city: CustomAgentStateWithNotRequired = { |
| 135 | + "messages": [HumanMessage("What's the weather?"), ai_msg], |
| 136 | + "city": "San Francisco", |
| 137 | + } |
| 138 | + |
| 139 | + result = tool_node.invoke( |
| 140 | + state_with_city, |
| 141 | + config=_create_config_with_runtime(state=state_with_city), |
| 142 | + ) |
| 143 | + |
| 144 | + assert len(result["messages"]) == 1 |
| 145 | + tool_msg = result["messages"][0] |
| 146 | + assert isinstance(tool_msg, ToolMessage) |
| 147 | + assert "San Francisco" in tool_msg.content |
| 148 | + |
| 149 | + |
| 150 | +@pytest.mark.skipif( |
| 151 | + sys.version_info < (3, 11), |
| 152 | + reason="InjectedState field extraction from Optional[Annotated[...]] not supported on Python <3.11", |
| 153 | +) |
| 154 | +def test_create_react_agent_injected_state_not_required_field_missing(): |
| 155 | + """Test create_react_agent with InjectedState using NotRequired field that is missing. |
| 156 | +
|
| 157 | + This verifies the fix for https://github.com/langchain-ai/langchain/issues/35585 |
| 158 | + """ |
| 159 | + model = FakeToolCallingModel( |
| 160 | + tool_calls=[ |
| 161 | + [{"name": "get_weather", "args": {}, "id": "call_1"}], |
| 162 | + [], # No more tool calls, agent should stop |
| 163 | + ] |
| 164 | + ) |
| 165 | + |
| 166 | + agent = create_react_agent( |
| 167 | + model, |
| 168 | + tools=[get_weather], |
| 169 | + state_schema=CustomAgentStateWithNotRequired, |
| 170 | + ) |
| 171 | + |
| 172 | + # Invoke WITHOUT the city field - should work, injecting None |
| 173 | + result = agent.invoke( |
| 174 | + {"messages": [HumanMessage("What's the weather?")]}, |
| 175 | + ) |
| 176 | + |
| 177 | + # Check that the tool was called successfully with None injected |
| 178 | + messages = result["messages"] |
| 179 | + tool_messages = [m for m in messages if isinstance(m, ToolMessage)] |
| 180 | + assert len(tool_messages) == 1 |
| 181 | + assert "No city provided" in tool_messages[0].content |
| 182 | + |
| 183 | + |
| 184 | +@pytest.mark.skipif( |
| 185 | + sys.version_info < (3, 11), |
| 186 | + reason="InjectedState field extraction from Optional[Annotated[...]] not supported on Python <3.11", |
| 187 | +) |
| 188 | +def test_create_react_agent_injected_state_not_required_field_present(): |
| 189 | + """Test create_react_agent with InjectedState using NotRequired field that IS present.""" |
| 190 | + model = FakeToolCallingModel( |
| 191 | + tool_calls=[ |
| 192 | + [{"name": "get_weather", "args": {}, "id": "call_1"}], |
| 193 | + [], # No more tool calls, agent should stop |
| 194 | + ] |
| 195 | + ) |
| 196 | + |
| 197 | + agent = create_react_agent( |
| 198 | + model, |
| 199 | + tools=[get_weather], |
| 200 | + state_schema=CustomAgentStateWithNotRequired, |
| 201 | + ) |
| 202 | + |
| 203 | + # Invoke WITH the city field |
| 204 | + result = agent.invoke( |
| 205 | + { |
| 206 | + "messages": [HumanMessage("What's the weather?")], |
| 207 | + "city": "San Francisco", |
| 208 | + }, |
| 209 | + ) |
| 210 | + |
| 211 | + # Check that the tool was called successfully |
| 212 | + messages = result["messages"] |
| 213 | + tool_messages = [m for m in messages if isinstance(m, ToolMessage)] |
| 214 | + assert len(tool_messages) == 1 |
| 215 | + assert "San Francisco" in tool_messages[0].content |
| 216 | + |
| 217 | + |
| 218 | +@tool |
| 219 | +def get_weather_optional(city: Annotated[str | None, InjectedState("city")]) -> str: |
| 220 | + """Get weather for a given city (accepts None).""" |
| 221 | + if city is None: |
| 222 | + return "Please provide a city!" |
| 223 | + return f"It's always sunny in {city}!" |
| 224 | + |
| 225 | + |
| 226 | +def test_pydantic_state_with_default_field_missing_works(): |
| 227 | + """Test that Pydantic state with Optional field and default=None works when field is missing. |
| 228 | +
|
| 229 | + This is the workaround suggested in the issue comments - using Pydantic BaseModel |
| 230 | + with `city: Optional[str] = Field(default=None)` instead of TypedDict with NotRequired. |
| 231 | + """ |
| 232 | + model = FakeToolCallingModel( |
| 233 | + tool_calls=[ |
| 234 | + [{"name": "get_weather_optional", "args": {}, "id": "call_1"}], |
| 235 | + [], # No more tool calls, agent should stop |
| 236 | + ] |
| 237 | + ) |
| 238 | + |
| 239 | + agent = create_react_agent( |
| 240 | + model, |
| 241 | + tools=[get_weather_optional], |
| 242 | + state_schema=CustomAgentStatePydanticWithDefault, |
| 243 | + ) |
| 244 | + |
| 245 | + # Invoke WITHOUT the city field - should work because Pydantic provides default |
| 246 | + result = agent.invoke( |
| 247 | + {"messages": [HumanMessage("What's the weather?")]}, |
| 248 | + ) |
| 249 | + |
| 250 | + # Check that the tool was called successfully with None |
| 251 | + messages = result["messages"] |
| 252 | + tool_messages = [m for m in messages if isinstance(m, ToolMessage)] |
| 253 | + assert len(tool_messages) == 1 |
| 254 | + assert "Please provide a city!" in tool_messages[0].content |
| 255 | + |
| 256 | + |
| 257 | +def test_pydantic_state_with_default_field_present_works(): |
| 258 | + """Test that Pydantic state with Optional field works when field IS present.""" |
| 259 | + model = FakeToolCallingModel( |
| 260 | + tool_calls=[ |
| 261 | + [{"name": "get_weather_optional", "args": {}, "id": "call_1"}], |
| 262 | + [], # No more tool calls, agent should stop |
| 263 | + ] |
| 264 | + ) |
| 265 | + |
| 266 | + agent = create_react_agent( |
| 267 | + model, |
| 268 | + tools=[get_weather_optional], |
| 269 | + state_schema=CustomAgentStatePydanticWithDefault, |
| 270 | + ) |
| 271 | + |
| 272 | + # Invoke WITH the city field |
| 273 | + result = agent.invoke( |
| 274 | + { |
| 275 | + "messages": [HumanMessage("What's the weather?")], |
| 276 | + "city": "San Francisco", |
| 277 | + }, |
| 278 | + ) |
| 279 | + |
| 280 | + # Check that the tool was called successfully |
| 281 | + messages = result["messages"] |
| 282 | + tool_messages = [m for m in messages if isinstance(m, ToolMessage)] |
| 283 | + assert len(tool_messages) == 1 |
| 284 | + assert "San Francisco" in tool_messages[0].content |
0 commit comments