Skip to content

Commit 2d766c4

Browse files
feat(session): add dirty flag to skip unnecessary agent state persistence (#1803)
Co-authored-by: Strands Agent <217235299+strands-agent@users.noreply.github.com>
1 parent 697e55c commit 2d766c4

6 files changed

Lines changed: 436 additions & 1 deletion

File tree

src/strands/interrupt.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ class _InterruptState:
5252
interrupts: dict[str, Interrupt] = field(default_factory=dict)
5353
context: dict[str, Any] = field(default_factory=dict)
5454
activated: bool = False
55+
_version: int = field(default=0, compare=False, repr=False)
5556

5657
def activate(self) -> None:
5758
"""Activate the interrupt state."""
5859
self.activated = True
60+
self._version += 1
5961

6062
def deactivate(self) -> None:
6163
"""Deacitvate the interrupt state.
@@ -65,6 +67,7 @@ def deactivate(self) -> None:
6567
self.interrupts = {}
6668
self.context = {}
6769
self.activated = False
70+
self._version += 1
6871

6972
def resume(self, prompt: "AgentInput") -> None:
7073
"""Configure the interrupt state if resuming from an interrupt event.
@@ -100,10 +103,27 @@ def resume(self, prompt: "AgentInput") -> None:
100103
self.interrupts[interrupt_id].response = interrupt_response
101104

102105
self.context["responses"] = contents
106+
self._version += 1
107+
108+
def _get_version(self) -> int:
109+
"""Get the current version number of the interrupt state.
110+
111+
The version is incremented each time activate(), deactivate(), or resume() is called.
112+
Consumers can compare versions to detect changes without requiring
113+
explicit dirty flag clearing.
114+
115+
Returns:
116+
The current version number.
117+
"""
118+
return self._version
103119

104120
def to_dict(self) -> dict[str, Any]:
105121
"""Serialize to dict for session management."""
106-
return asdict(self)
122+
return {
123+
"interrupts": {k: v.to_dict() for k, v in self.interrupts.items()},
124+
"context": self.context,
125+
"activated": self.activated,
126+
}
107127

108128
@classmethod
109129
def from_dict(cls, data: dict[str, Any]) -> "_InterruptState":

src/strands/session/repository_session_manager.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Repository session manager implementation."""
22

3+
import copy
34
import logging
45
from typing import TYPE_CHECKING, Any
56

@@ -59,6 +60,9 @@ def __init__(
5960
# Keep track of the latest message of each agent in case we need to redact it.
6061
self._latest_agent_message: dict[str, SessionMessage | None] = {}
6162

63+
# Track the previously synced internal state for each agent to detect changes.
64+
self._last_synced_internal_state: dict[str, dict[str, Any]] = {}
65+
6266
def append_message(self, message: Message, agent: "Agent", **kwargs: Any) -> None:
6367
"""Append a message to the agent's session.
6468
@@ -95,15 +99,66 @@ def redact_latest_message(self, redact_message: Message, agent: "Agent", **kwarg
9599
def sync_agent(self, agent: "Agent", **kwargs: Any) -> None:
96100
"""Serialize and update the agent into the session repository.
97101
102+
Only updates the agent if state has been modified or internal state has changed.
103+
This optimization reduces unnecessary I/O operations when the agent processes
104+
messages without modifying its state.
105+
98106
Args:
99107
agent: Agent to sync to the session.
100108
**kwargs: Additional keyword arguments for future extensibility.
101109
"""
110+
# Get current versions and conversation manager state
111+
current_state_version = agent.state._get_version()
112+
current_interrupt_state_version = agent._interrupt_state._get_version()
113+
current_conversation_manager_state = agent.conversation_manager.get_state()
114+
115+
# Check if we have a previous state to compare against
116+
last_synced = self._last_synced_internal_state.get(agent.agent_id)
117+
118+
# Determine if we need to update by comparing versions
119+
if last_synced is None:
120+
# First sync for this agent - always update
121+
state_changed = True
122+
internal_state_changed = True
123+
conversation_manager_state_changed = True
124+
else:
125+
state_changed = current_state_version != last_synced.get("state_version")
126+
internal_state_changed = current_interrupt_state_version != last_synced.get("interrupt_state_version")
127+
conversation_manager_state_changed = (
128+
current_conversation_manager_state != last_synced.get("conversation_manager_state")
129+
)
130+
131+
if not state_changed and not internal_state_changed and not conversation_manager_state_changed:
132+
logger.debug(
133+
"agent_id=<%s> | session_id=<%s> | skipping sync, no changes detected",
134+
agent.agent_id,
135+
self.session_id,
136+
)
137+
return
138+
139+
logger.debug(
140+
"agent_id=<%s> | session_id=<%s> | state_changed=<%s>, internal_state_changed=<%s>, "
141+
"conversation_manager_state_changed=<%s> | syncing agent",
142+
agent.agent_id,
143+
self.session_id,
144+
state_changed,
145+
internal_state_changed,
146+
conversation_manager_state_changed,
147+
)
148+
149+
# Perform the update
102150
self.session_repository.update_agent(
103151
self.session_id,
104152
SessionAgent.from_agent(agent),
105153
)
106154

155+
# Update tracked versions after successful sync
156+
self._last_synced_internal_state[agent.agent_id] = {
157+
"state_version": current_state_version,
158+
"interrupt_state_version": current_interrupt_state_version,
159+
"conversation_manager_state": copy.deepcopy(current_conversation_manager_state),
160+
}
161+
107162
def initialize(self, agent: "Agent", **kwargs: Any) -> None:
108163
"""Initialize an agent with a session.
109164

src/strands/types/json_dict.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class JSONSerializableDict:
1515
def __init__(self, initial_state: dict[str, Any] | None = None):
1616
"""Initialize JSONSerializableDict."""
1717
self._data: dict[str, Any]
18+
self._version: int = 0
1819
if initial_state:
1920
self._validate_json_serializable(initial_state)
2021
self._data = copy.deepcopy(initial_state)
@@ -34,6 +35,7 @@ def set(self, key: str, value: Any) -> None:
3435
self._validate_key(key)
3536
self._validate_json_serializable(value)
3637
self._data[key] = copy.deepcopy(value)
38+
self._version += 1
3739

3840
def get(self, key: str | None = None) -> Any:
3941
"""Get a value or entire data.
@@ -57,6 +59,19 @@ def delete(self, key: str) -> None:
5759
"""
5860
self._validate_key(key)
5961
self._data.pop(key, None)
62+
self._version += 1
63+
64+
def _get_version(self) -> int:
65+
"""Get the current version number of the store.
66+
67+
The version is incremented each time set() or delete() is called.
68+
Consumers can compare versions to detect changes without requiring
69+
explicit dirty flag clearing.
70+
71+
Returns:
72+
The current version number.
73+
"""
74+
return self._version
6075

6176
def _validate_key(self, key: str) -> None:
6277
"""Validate that a key is valid.

tests/strands/session/test_repository_session_manager.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,3 +595,219 @@ def test_fix_broken_tool_use_does_not_affect_normal_conversations(session_manage
595595

596596
# Should remain unchanged
597597
assert fixed_messages == messages
598+
599+
600+
# ============================================================================
601+
# Conditional Sync Tests
602+
# ============================================================================
603+
604+
605+
def test_sync_agent_skips_update_when_state_not_dirty_and_internal_state_unchanged(mock_repository):
606+
"""Test that sync_agent() skips update_agent() when state is not dirty and internal state unchanged."""
607+
session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository)
608+
609+
# Create and initialize agent
610+
agent = Agent(agent_id="test-agent", session_manager=session_manager)
611+
612+
# Track update_agent calls
613+
update_agent_calls = []
614+
original_update_agent = mock_repository.update_agent
615+
616+
def tracking_update_agent(session_id, session_agent):
617+
update_agent_calls.append((session_id, session_agent))
618+
return original_update_agent(session_id, session_agent)
619+
620+
mock_repository.update_agent = tracking_update_agent
621+
622+
# First sync should update (to establish baseline)
623+
session_manager.sync_agent(agent)
624+
assert len(update_agent_calls) == 1
625+
626+
# Clear tracking
627+
update_agent_calls.clear()
628+
629+
# Second sync without changes should skip update
630+
session_manager.sync_agent(agent)
631+
assert len(update_agent_calls) == 0
632+
633+
634+
def test_sync_agent_calls_update_when_state_is_dirty(mock_repository):
635+
"""Test that sync_agent() calls update_agent() when agent.state is dirty."""
636+
session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository)
637+
638+
# Create and initialize agent
639+
agent = Agent(agent_id="test-agent", session_manager=session_manager)
640+
641+
# Track update_agent calls
642+
update_agent_calls = []
643+
original_update_agent = mock_repository.update_agent
644+
645+
def tracking_update_agent(session_id, session_agent):
646+
update_agent_calls.append((session_id, session_agent))
647+
return original_update_agent(session_id, session_agent)
648+
649+
mock_repository.update_agent = tracking_update_agent
650+
651+
# First sync to establish baseline
652+
session_manager.sync_agent(agent)
653+
update_agent_calls.clear()
654+
655+
# Modify state (makes it dirty)
656+
agent.state.set("key", "value")
657+
658+
# Sync should call update_agent because state is dirty
659+
session_manager.sync_agent(agent)
660+
assert len(update_agent_calls) == 1
661+
662+
663+
def test_sync_agent_calls_update_when_internal_state_changed(mock_repository):
664+
"""Test that sync_agent() calls update_agent() when internal state (interrupt_state) is dirty."""
665+
session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository)
666+
667+
# Create and initialize agent
668+
agent = Agent(agent_id="test-agent", session_manager=session_manager)
669+
670+
# Track update_agent calls
671+
update_agent_calls = []
672+
original_update_agent = mock_repository.update_agent
673+
674+
def tracking_update_agent(session_id, session_agent):
675+
update_agent_calls.append((session_id, session_agent))
676+
return original_update_agent(session_id, session_agent)
677+
678+
mock_repository.update_agent = tracking_update_agent
679+
680+
# First sync to establish baseline
681+
session_manager.sync_agent(agent)
682+
update_agent_calls.clear()
683+
684+
# Modify internal state (activate interrupt state which sets dirty flag)
685+
agent._interrupt_state.activate()
686+
687+
# Sync should call update_agent because internal state is dirty
688+
session_manager.sync_agent(agent)
689+
assert len(update_agent_calls) == 1
690+
691+
692+
def test_sync_agent_calls_update_when_conversation_manager_state_changed(mock_repository):
693+
"""Test that sync_agent() calls update_agent() when conversation manager state changed."""
694+
session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository)
695+
696+
# Create and initialize agent
697+
agent = Agent(agent_id="test-agent", session_manager=session_manager)
698+
699+
# Track update_agent calls
700+
update_agent_calls = []
701+
original_update_agent = mock_repository.update_agent
702+
703+
def tracking_update_agent(session_id, session_agent):
704+
update_agent_calls.append((session_id, session_agent))
705+
return original_update_agent(session_id, session_agent)
706+
707+
mock_repository.update_agent = tracking_update_agent
708+
709+
# First sync to establish baseline
710+
session_manager.sync_agent(agent)
711+
update_agent_calls.clear()
712+
713+
# Modify conversation manager state
714+
agent.conversation_manager.removed_message_count = 5
715+
716+
# Sync should call update_agent because conversation manager state changed
717+
session_manager.sync_agent(agent)
718+
assert len(update_agent_calls) == 1
719+
720+
721+
def test_sync_agent_tracks_version_after_successful_sync(mock_repository):
722+
"""Test that sync_agent() tracks version after successful sync."""
723+
session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository)
724+
725+
# Create and initialize agent
726+
agent = Agent(agent_id="test-agent", session_manager=session_manager)
727+
728+
# First sync to establish baseline
729+
session_manager.sync_agent(agent)
730+
initial_version = agent.state._get_version()
731+
732+
# Modify state (increments version)
733+
agent.state.set("key", "value")
734+
assert agent.state._get_version() == initial_version + 1
735+
736+
# Track update_agent calls
737+
update_agent_calls = []
738+
original_update_agent = mock_repository.update_agent
739+
740+
def tracking_update_agent(session_id, session_agent):
741+
update_agent_calls.append((session_id, session_agent))
742+
return original_update_agent(session_id, session_agent)
743+
744+
mock_repository.update_agent = tracking_update_agent
745+
746+
# Sync should update because version changed
747+
session_manager.sync_agent(agent)
748+
assert len(update_agent_calls) == 1
749+
750+
# Second sync without changes should skip
751+
update_agent_calls.clear()
752+
session_manager.sync_agent(agent)
753+
assert len(update_agent_calls) == 0
754+
755+
756+
def test_sync_agent_retries_on_failure(mock_repository):
757+
"""Test that sync_agent() retries on next call if update_agent() fails."""
758+
session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository)
759+
760+
# Create and initialize agent
761+
agent = Agent(agent_id="test-agent", session_manager=session_manager)
762+
763+
# First sync to establish baseline
764+
session_manager.sync_agent(agent)
765+
766+
# Modify state (increments version)
767+
agent.state.set("key", "value")
768+
769+
# Make update_agent fail
770+
def failing_update_agent(session_id, session_agent):
771+
raise SessionException("Update failed")
772+
773+
mock_repository.update_agent = failing_update_agent
774+
775+
# Sync should fail
776+
with pytest.raises(SessionException, match="Update failed"):
777+
session_manager.sync_agent(agent)
778+
779+
# Restore working update_agent
780+
update_agent_calls = []
781+
original_update_agent = MockedSessionRepository.update_agent
782+
783+
def tracking_update_agent(self, session_id, session_agent):
784+
update_agent_calls.append((session_id, session_agent))
785+
return original_update_agent(self, session_id, session_agent)
786+
787+
mock_repository.update_agent = lambda sid, sa: tracking_update_agent(mock_repository, sid, sa)
788+
789+
# Retry should work because version wasn't updated on failure
790+
session_manager.sync_agent(agent)
791+
assert len(update_agent_calls) == 1
792+
793+
794+
def test_sync_agent_first_sync_always_updates(mock_repository):
795+
"""Test that the first sync_agent() call always updates (no previous state to compare)."""
796+
session_manager = RepositorySessionManager(session_id="test-session", session_repository=mock_repository)
797+
798+
# Create and initialize agent
799+
agent = Agent(agent_id="test-agent", session_manager=session_manager)
800+
801+
# Track update_agent calls
802+
update_agent_calls = []
803+
original_update_agent = mock_repository.update_agent
804+
805+
def tracking_update_agent(session_id, session_agent):
806+
update_agent_calls.append((session_id, session_agent))
807+
return original_update_agent(session_id, session_agent)
808+
809+
mock_repository.update_agent = tracking_update_agent
810+
811+
# First sync should always update (no previous state)
812+
session_manager.sync_agent(agent)
813+
assert len(update_agent_calls) == 1

0 commit comments

Comments
 (0)