Skip to content

Commit f093702

Browse files
fix(prebuilt): handle injected NotRequired keys (#7392)
Resolves langchain-ai/langchain#35585 This would previously raise KeyError: ```python from typing import Annotated from langchain_core.tools import tool from langchain.agents import create_agent from typing_extensions import NotRequired from langgraph.prebuilt import InjectedState from langchain.agents import AgentState class CustomAgentState(AgentState): city: NotRequired[str] @tool def get_weather(city: Annotated[str | None, InjectedState("city")] = None) -> str: """Get weather for a given city.""" if city is None: city = "Boston" return f"It's always sunny in {city}!" agent = create_agent( model="claude-sonnet-4-6", tools=[get_weather], system_prompt="You are a helpful assistant", state_schema=CustomAgentState, ) input_message = { "role": "user", "content": "What's the weather?", } result = agent.invoke({"messages": [input_message]}) for m in result["messages"]: m.pretty_print() ``` --------- Co-authored-by: Sydney Runkle <sydneymarierunkle@gmail.com>
1 parent 51cbdbd commit f093702

2 files changed

Lines changed: 303 additions & 7 deletions

File tree

libs/prebuilt/langgraph/prebuilt/tool_node.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ def my_tool(
614614
store: str | None
615615
runtime: str | None
616616
all_injected_keys: set[str]
617+
_optional_state_args: set[str]
617618

618619

619620
class ToolNode(RunnableCallable):
@@ -1333,7 +1334,7 @@ def _inject_tool_args(
13331334
return tool_call
13341335

13351336
tool_call_copy: ToolCall = copy(tool_call)
1336-
injected_args = {}
1337+
injected_args: dict[str, Any] = {}
13371338

13381339
# Inject state
13391340
if injected.state:
@@ -1361,14 +1362,20 @@ def _inject_tool_args(
13611362
# Extract state values
13621363
if isinstance(state, dict):
13631364
for tool_arg, state_field in injected.state.items():
1364-
injected_args[tool_arg] = (
1365-
state[state_field] if state_field else state
1366-
)
1365+
if not state_field:
1366+
injected_args[tool_arg] = state
1367+
elif state_field in state:
1368+
injected_args[tool_arg] = state[state_field]
1369+
elif tool_arg not in injected._optional_state_args:
1370+
raise KeyError(state_field)
13671371
else:
13681372
for tool_arg, state_field in injected.state.items():
1369-
injected_args[tool_arg] = (
1370-
getattr(state, state_field) if state_field else state
1371-
)
1373+
if not state_field:
1374+
injected_args[tool_arg] = state
1375+
elif hasattr(state, state_field):
1376+
injected_args[tool_arg] = getattr(state, state_field)
1377+
elif tool_arg not in injected._optional_state_args:
1378+
raise AttributeError(state_field)
13721379

13731380
# Inject store
13741381
if injected.store:
@@ -1859,6 +1866,7 @@ def _get_all_injected_args(tool: BaseTool) -> _InjectedArgs:
18591866
store_arg: str | None = None
18601867
runtime_arg: str | None = None
18611868
all_injected_keys: set[str] = set()
1869+
_optional_state_args: set[str] = set()
18621870

18631871
for name, type_ in all_annotations.items():
18641872
# Track all InjectedToolArg-annotated params (including custom subclasses)
@@ -1873,6 +1881,9 @@ def _get_all_injected_args(tool: BaseTool) -> _InjectedArgs:
18731881
if state_inj := _get_injection_from_type(type_, InjectedState):
18741882
if isinstance(state_inj, InjectedState) and state_inj.field:
18751883
state_args[name] = state_inj.field
1884+
field_info = full_schema.model_fields.get(name)
1885+
if field_info and not field_info.is_required():
1886+
_optional_state_args.add(name)
18761887
else:
18771888
state_args[name] = None
18781889

@@ -1889,4 +1900,5 @@ def _get_all_injected_args(tool: BaseTool) -> _InjectedArgs:
18891900
store=store_arg,
18901901
runtime=runtime_arg,
18911902
all_injected_keys=all_injected_keys,
1903+
_optional_state_args=_optional_state_args,
18921904
)
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
"""Test InjectedState with NotRequired state fields.
2+
3+
This tests the fix for https://github.com/langchain-ai/langchain/issues/35585
4+
5+
When using InjectedState(<field>) on a tool parameter, and the referenced field is
6+
declared as NotRequired in the custom state schema, the ToolNode should gracefully
7+
handle missing fields by injecting None instead of raising KeyError.
8+
"""
9+
10+
import sys
11+
from typing import Annotated
12+
13+
import pytest
14+
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, ToolMessage
15+
from langchain_core.tools import tool
16+
from langgraph.graph.message import add_messages
17+
from pydantic import BaseModel, Field
18+
from typing_extensions import NotRequired
19+
20+
from langgraph.prebuilt import InjectedState, ToolNode, create_react_agent
21+
from langgraph.prebuilt.chat_agent_executor import AgentState
22+
23+
from .model import FakeToolCallingModel
24+
25+
26+
class CustomAgentStateWithNotRequired(AgentState):
27+
"""Custom state with a NotRequired field (TypedDict style)."""
28+
29+
city: NotRequired[str]
30+
31+
32+
class CustomAgentStatePydanticWithDefault(BaseModel):
33+
"""Custom state with Optional field and default (Pydantic style)."""
34+
35+
messages: Annotated[list[AnyMessage], add_messages]
36+
remaining_steps: int = Field(default=10)
37+
city: str | None = Field(default=None)
38+
39+
40+
@tool
41+
def get_weather(city: Annotated[str | None, InjectedState("city")] = None) -> str:
42+
"""Get weather for a given city."""
43+
if city is None:
44+
return "No city provided"
45+
return f"It's always sunny in {city}!"
46+
47+
48+
def _create_mock_runtime(
49+
state: dict | None = None,
50+
store=None,
51+
):
52+
"""Create a mock Runtime for testing ToolNode directly."""
53+
from unittest.mock import Mock
54+
55+
from langgraph.runtime import Runtime
56+
57+
mock_runtime = Mock(spec=Runtime)
58+
mock_runtime.context = {}
59+
return mock_runtime
60+
61+
62+
def _create_config_with_runtime(store=None, state=None):
63+
"""Create a RunnableConfig with mocked runtime for direct ToolNode testing."""
64+
from langgraph.prebuilt.tool_node import ToolRuntime
65+
66+
tool_runtime = ToolRuntime(
67+
state=state or {},
68+
config={},
69+
context={},
70+
store=store,
71+
stream_writer=None,
72+
tool_call_id="test_id",
73+
)
74+
return {
75+
"configurable": {
76+
"__pregel_runtime": _create_mock_runtime(),
77+
"__tool_runtime__": tool_runtime,
78+
}
79+
}
80+
81+
82+
@pytest.mark.skipif(
83+
sys.version_info < (3, 11),
84+
reason="InjectedState field extraction from Optional[Annotated[...]] not supported on Python <3.11",
85+
)
86+
def test_injected_state_not_required_field_missing_injects_none():
87+
"""Test that InjectedState with NotRequired field injects None when field is missing.
88+
89+
This verifies the fix for https://github.com/langchain-ai/langchain/issues/35585
90+
"""
91+
tool_node = ToolNode([get_weather])
92+
93+
tool_call = {
94+
"name": "get_weather",
95+
"args": {},
96+
"id": "call_1",
97+
"type": "tool_call",
98+
}
99+
ai_msg = AIMessage("Let me check the weather", tool_calls=[tool_call])
100+
101+
# State WITHOUT the "city" field - should inject None instead of raising KeyError
102+
state_without_city: CustomAgentStateWithNotRequired = {
103+
"messages": [HumanMessage("What's the weather?"), ai_msg],
104+
}
105+
106+
result = tool_node.invoke(
107+
state_without_city,
108+
config=_create_config_with_runtime(state=state_without_city),
109+
)
110+
111+
assert len(result["messages"]) == 1
112+
tool_msg = result["messages"][0]
113+
assert isinstance(tool_msg, ToolMessage)
114+
assert "No city provided" in tool_msg.content
115+
116+
117+
@pytest.mark.skipif(
118+
sys.version_info < (3, 11),
119+
reason="InjectedState field extraction from Optional[Annotated[...]] not supported on Python <3.11",
120+
)
121+
def test_injected_state_not_required_field_present_works():
122+
"""Test that InjectedState with NotRequired field works when field IS present."""
123+
tool_node = ToolNode([get_weather])
124+
125+
tool_call = {
126+
"name": "get_weather",
127+
"args": {},
128+
"id": "call_1",
129+
"type": "tool_call",
130+
}
131+
ai_msg = AIMessage("Let me check the weather", tool_calls=[tool_call])
132+
133+
# State WITH the "city" field - this should work
134+
state_with_city: CustomAgentStateWithNotRequired = {
135+
"messages": [HumanMessage("What's the weather?"), ai_msg],
136+
"city": "San Francisco",
137+
}
138+
139+
result = tool_node.invoke(
140+
state_with_city,
141+
config=_create_config_with_runtime(state=state_with_city),
142+
)
143+
144+
assert len(result["messages"]) == 1
145+
tool_msg = result["messages"][0]
146+
assert isinstance(tool_msg, ToolMessage)
147+
assert "San Francisco" in tool_msg.content
148+
149+
150+
@pytest.mark.skipif(
151+
sys.version_info < (3, 11),
152+
reason="InjectedState field extraction from Optional[Annotated[...]] not supported on Python <3.11",
153+
)
154+
def test_create_react_agent_injected_state_not_required_field_missing():
155+
"""Test create_react_agent with InjectedState using NotRequired field that is missing.
156+
157+
This verifies the fix for https://github.com/langchain-ai/langchain/issues/35585
158+
"""
159+
model = FakeToolCallingModel(
160+
tool_calls=[
161+
[{"name": "get_weather", "args": {}, "id": "call_1"}],
162+
[], # No more tool calls, agent should stop
163+
]
164+
)
165+
166+
agent = create_react_agent(
167+
model,
168+
tools=[get_weather],
169+
state_schema=CustomAgentStateWithNotRequired,
170+
)
171+
172+
# Invoke WITHOUT the city field - should work, injecting None
173+
result = agent.invoke(
174+
{"messages": [HumanMessage("What's the weather?")]},
175+
)
176+
177+
# Check that the tool was called successfully with None injected
178+
messages = result["messages"]
179+
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
180+
assert len(tool_messages) == 1
181+
assert "No city provided" in tool_messages[0].content
182+
183+
184+
@pytest.mark.skipif(
185+
sys.version_info < (3, 11),
186+
reason="InjectedState field extraction from Optional[Annotated[...]] not supported on Python <3.11",
187+
)
188+
def test_create_react_agent_injected_state_not_required_field_present():
189+
"""Test create_react_agent with InjectedState using NotRequired field that IS present."""
190+
model = FakeToolCallingModel(
191+
tool_calls=[
192+
[{"name": "get_weather", "args": {}, "id": "call_1"}],
193+
[], # No more tool calls, agent should stop
194+
]
195+
)
196+
197+
agent = create_react_agent(
198+
model,
199+
tools=[get_weather],
200+
state_schema=CustomAgentStateWithNotRequired,
201+
)
202+
203+
# Invoke WITH the city field
204+
result = agent.invoke(
205+
{
206+
"messages": [HumanMessage("What's the weather?")],
207+
"city": "San Francisco",
208+
},
209+
)
210+
211+
# Check that the tool was called successfully
212+
messages = result["messages"]
213+
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
214+
assert len(tool_messages) == 1
215+
assert "San Francisco" in tool_messages[0].content
216+
217+
218+
@tool
219+
def get_weather_optional(city: Annotated[str | None, InjectedState("city")]) -> str:
220+
"""Get weather for a given city (accepts None)."""
221+
if city is None:
222+
return "Please provide a city!"
223+
return f"It's always sunny in {city}!"
224+
225+
226+
def test_pydantic_state_with_default_field_missing_works():
227+
"""Test that Pydantic state with Optional field and default=None works when field is missing.
228+
229+
This is the workaround suggested in the issue comments - using Pydantic BaseModel
230+
with `city: Optional[str] = Field(default=None)` instead of TypedDict with NotRequired.
231+
"""
232+
model = FakeToolCallingModel(
233+
tool_calls=[
234+
[{"name": "get_weather_optional", "args": {}, "id": "call_1"}],
235+
[], # No more tool calls, agent should stop
236+
]
237+
)
238+
239+
agent = create_react_agent(
240+
model,
241+
tools=[get_weather_optional],
242+
state_schema=CustomAgentStatePydanticWithDefault,
243+
)
244+
245+
# Invoke WITHOUT the city field - should work because Pydantic provides default
246+
result = agent.invoke(
247+
{"messages": [HumanMessage("What's the weather?")]},
248+
)
249+
250+
# Check that the tool was called successfully with None
251+
messages = result["messages"]
252+
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
253+
assert len(tool_messages) == 1
254+
assert "Please provide a city!" in tool_messages[0].content
255+
256+
257+
def test_pydantic_state_with_default_field_present_works():
258+
"""Test that Pydantic state with Optional field works when field IS present."""
259+
model = FakeToolCallingModel(
260+
tool_calls=[
261+
[{"name": "get_weather_optional", "args": {}, "id": "call_1"}],
262+
[], # No more tool calls, agent should stop
263+
]
264+
)
265+
266+
agent = create_react_agent(
267+
model,
268+
tools=[get_weather_optional],
269+
state_schema=CustomAgentStatePydanticWithDefault,
270+
)
271+
272+
# Invoke WITH the city field
273+
result = agent.invoke(
274+
{
275+
"messages": [HumanMessage("What's the weather?")],
276+
"city": "San Francisco",
277+
},
278+
)
279+
280+
# Check that the tool was called successfully
281+
messages = result["messages"]
282+
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
283+
assert len(tool_messages) == 1
284+
assert "San Francisco" in tool_messages[0].content

0 commit comments

Comments
 (0)