Skip to content

Commit 8631f51

Browse files
author
Mateusz
committed
fix(acp): isolate runtimes by client session and sync history
- Key ACPProcessRuntime pool by (project_dir, model, client_session_id) - Replace history_injected with HistoryState (message_count + prefix hash) - On contiguous new messages, send serialize_tail; on divergence, kill and resync - Add ACPTranscriptSerializer.serialize_tail for mid-session context - Extend unit tests for isolation, incremental prompts, and divergence Made-with: Cursor
1 parent 5dd5cd4 commit 8631f51

5 files changed

Lines changed: 373 additions & 36 deletions

File tree

src/connectors/acp_core/base_connector.py

Lines changed: 94 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import asyncio
44
import contextlib
5+
import hashlib
56
import json
67
import logging
78
import os
@@ -20,6 +21,7 @@
2021
ACPSessionUpdate,
2122
AcpStreamPiece,
2223
ACPUpdateContent,
24+
HistoryState,
2325
)
2426
from src.connectors.base import LLMBackend, add_vendor_prefix, strip_vendor_prefix
2527
from src.connectors.contracts import ConnectorChatCompletionsRequest
@@ -107,7 +109,7 @@ def __init__(
107109
self._process_timeout = DEFAULT_PROCESS_TIMEOUT
108110
self._idle_timeout = DEFAULT_IDLE_TIMEOUT
109111
self._runtime_pool_lock = asyncio.Lock()
110-
self._runtimes: dict[tuple[str, str], ACPProcessRuntime] = {}
112+
self._runtimes: dict[tuple[str, str, str], ACPProcessRuntime] = {}
111113

112114
@property
113115
def has_static_credentials(self) -> bool:
@@ -153,19 +155,48 @@ async def _handle_server_request(
153155
def _is_usable_directory(path: Path) -> bool:
154156
return path.exists() and path.is_dir() and os.access(path, os.R_OK)
155157

156-
def _build_runtime_key(self, project_dir: Path, model: str) -> tuple[str, str]:
157-
return (str(project_dir), model)
158+
def _build_runtime_key(
159+
self, project_dir: Path, model: str, client_session_id: str
160+
) -> tuple[str, str, str]:
161+
return (str(project_dir), model, client_session_id)
158162

159-
def _create_runtime(self, project_dir: Path, model: str) -> ACPProcessRuntime:
163+
def _create_runtime(
164+
self, project_dir: Path, model: str, client_session_id: str = "default"
165+
) -> ACPProcessRuntime:
160166
return ACPProcessRuntime(
161167
project_dir=project_dir,
162168
model=model,
169+
client_session_id=client_session_id,
163170
process_lock=asyncio.Lock(),
164171
request_lock=asyncio.Lock(),
165172
cancellation_lock=asyncio.Lock(),
166173
cancellation_event=asyncio.Event(),
167174
)
168175

176+
@staticmethod
177+
def _resolve_client_session_id(request: ConnectorChatCompletionsRequest) -> str:
178+
sid: str | None = None
179+
if request.context is not None and request.context.session_id:
180+
sid = request.context.session_id
181+
if not sid:
182+
raw = getattr(request.request, "session_id", None)
183+
if isinstance(raw, str) and raw.strip():
184+
sid = raw
185+
if isinstance(sid, str) and sid.strip():
186+
return sid.strip()
187+
return "default"
188+
189+
@staticmethod
190+
def _hash_messages_prefix(
191+
messages: Sequence[ChatMessage], end_exclusive: int
192+
) -> str:
193+
if end_exclusive <= 0:
194+
return hashlib.sha256(b"").hexdigest()
195+
slice_msgs = messages[:end_exclusive]
196+
payload = [m.model_dump(mode="json", exclude_none=True) for m in slice_msgs]
197+
canonical = json.dumps(payload, sort_keys=True, separators=(",", ":"))
198+
return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
199+
169200
async def _acquire_runtime(
170201
self, request: ConnectorChatCompletionsRequest
171202
) -> ACPProcessRuntime:
@@ -174,12 +205,17 @@ async def _acquire_runtime(
174205
request.effective_model or self._model,
175206
self.VENDOR_PREFIX,
176207
)
177-
runtime_key = self._build_runtime_key(project_dir, requested_model)
208+
client_session_id = self._resolve_client_session_id(request)
209+
runtime_key = self._build_runtime_key(
210+
project_dir, requested_model, client_session_id
211+
)
178212

179213
async with self._runtime_pool_lock:
180214
runtime = self._runtimes.get(runtime_key)
181215
if runtime is None:
182-
runtime = self._create_runtime(project_dir, requested_model)
216+
runtime = self._create_runtime(
217+
project_dir, requested_model, client_session_id
218+
)
183219
self._runtimes[runtime_key] = runtime
184220

185221
await self._reap_idle_runtime(runtime)
@@ -276,14 +312,14 @@ async def _spawn_process(self, runtime: ACPProcessRuntime) -> None:
276312
runtime.initialized = False
277313
runtime.session_id = None
278314
runtime.message_id = 0
279-
runtime.history_injected = False
315+
runtime.history_state = None
280316
except Exception as exc:
281317
if new_process is not None:
282318
self._cleanup_process(new_process)
283319
runtime.process = None
284320
runtime.initialized = False
285321
runtime.session_id = None
286-
runtime.history_injected = False
322+
runtime.history_state = None
287323
raise APIConnectionError(
288324
message=f"Failed to start ACP process: {exc}",
289325
details={
@@ -323,7 +359,7 @@ def _cleanup_runtime_state(
323359
runtime.session_id = None
324360
runtime.message_id = 0
325361
runtime.last_activity = 0.0
326-
runtime.history_injected = False
362+
runtime.history_state = None
327363

328364
def _cleanup_process(self, process: subprocess.Popen[bytes] | None = None) -> None:
329365
if process is None:
@@ -767,14 +803,57 @@ async def _prepare_prompt_request_locked(
767803
await self._spawn_process(runtime)
768804
await self._initialize_runtime(runtime)
769805

770-
if not runtime.history_injected:
771-
user_message = ACPTranscriptSerializer.serialize(request.processed_messages)
772-
runtime.history_injected = True
806+
messages = list(request.processed_messages)
807+
if not messages:
808+
raise BackendError(message="No messages found in request")
809+
810+
state = runtime.history_state
811+
new_history_state: HistoryState
812+
user_message: str
813+
814+
if state is None:
815+
user_message = ACPTranscriptSerializer.serialize(messages)
816+
new_history_state = HistoryState(
817+
message_count=len(messages),
818+
prefix_hash=self._hash_messages_prefix(messages, len(messages)),
819+
)
773820
else:
774-
user_message = self._extract_user_message_as_string(
775-
request.processed_messages
821+
n = state.message_count
822+
prefix_hash = state.prefix_hash
823+
diverged = (
824+
len(messages) < n
825+
or self._hash_messages_prefix(messages, n) != prefix_hash
776826
)
777827

828+
if diverged:
829+
if logger.isEnabledFor(logging.INFO):
830+
logger.info(
831+
"ACP history diverged or shrank; resetting agent process "
832+
"(project=%s model=%s client_session=%s)",
833+
runtime.project_dir,
834+
runtime.model,
835+
runtime.client_session_id,
836+
)
837+
await self._kill_runtime(runtime)
838+
await self._spawn_process(runtime)
839+
await self._initialize_runtime(runtime)
840+
user_message = ACPTranscriptSerializer.serialize(messages)
841+
new_history_state = HistoryState(
842+
message_count=len(messages),
843+
prefix_hash=self._hash_messages_prefix(messages, len(messages)),
844+
)
845+
elif len(messages) == n:
846+
user_message = self._extract_user_message_as_string(messages)
847+
new_history_state = state
848+
else:
849+
user_message = ACPTranscriptSerializer.serialize_tail(messages, n)
850+
if not user_message.strip():
851+
user_message = self._extract_user_message_as_string(messages)
852+
new_history_state = HistoryState(
853+
message_count=len(messages),
854+
prefix_hash=self._hash_messages_prefix(messages, len(messages)),
855+
)
856+
778857
if not user_message:
779858
raise BackendError(message="No user message found in request")
780859

@@ -792,6 +871,7 @@ async def _prepare_prompt_request_locked(
792871
"session/prompt",
793872
prompt_params,
794873
)
874+
runtime.history_state = new_history_state
795875
return prompt_request_id, requested_model
796876

797877
async def _stream_response_with_lock(

src/connectors/acp_core/transcript.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,73 @@ def serialize(messages: Sequence[ChatMessage | dict[str, Any] | str | Any]) -> s
8181

8282
return "\n".join(lines)
8383

84+
@staticmethod
85+
def serialize_tail(
86+
messages: Sequence[ChatMessage | dict[str, Any] | str | Any],
87+
start_index: int,
88+
) -> str:
89+
"""Serialize messages from ``start_index`` through the final user turn.
90+
91+
Used when the ACP agent already saw messages ``[:start_index]`` and new
92+
messages were appended (e.g. non-ACP turns in between).
93+
"""
94+
if start_index <= 0:
95+
return ACPTranscriptSerializer.serialize(messages)
96+
if not messages or start_index >= len(messages):
97+
return ""
98+
99+
last_user_idx = -1
100+
for i in range(len(messages) - 1, -1, -1):
101+
if ACPTranscriptSerializer._get_role(messages[i]) == "user":
102+
last_user_idx = i
103+
break
104+
105+
if last_user_idx == -1:
106+
return ""
107+
108+
last_user_msg = ACPTranscriptSerializer._get_content(messages[last_user_idx])
109+
if last_user_idx < start_index:
110+
return last_user_msg
111+
112+
history_msgs = list(messages[start_index:last_user_idx])
113+
if not history_msgs:
114+
return last_user_msg
115+
116+
lines = [
117+
"[System Note: Additional conversation occurred since your last response. "
118+
"Here is the new context:]",
119+
"",
120+
"--- New Messages ---",
121+
]
122+
123+
for msg in history_msgs:
124+
role = ACPTranscriptSerializer._get_role(msg)
125+
content = ACPTranscriptSerializer._get_content(msg)
126+
127+
if role == "system":
128+
lines.append(f"**System:** {content}")
129+
elif role == "user":
130+
lines.append(f"**User:** {content}")
131+
elif role == "assistant":
132+
lines.append(f"**Assistant:** {content}")
133+
134+
tool_calls = ACPTranscriptSerializer._get_tool_calls(msg)
135+
for tc in tool_calls:
136+
func_name = tc.get("function", {}).get("name", "unknown")
137+
func_args = tc.get("function", {}).get("arguments", "{}")
138+
lines.append(f"*Tool Call (`{func_name}`)*: `{func_args}`")
139+
elif role == "tool":
140+
lines.append(f"*Tool Result*: `{content}`")
141+
else:
142+
lines.append(f"**{role.capitalize()}:** {content}")
143+
144+
lines.append("------------------------")
145+
lines.append("")
146+
lines.append("[Current Request]")
147+
lines.append(last_user_msg)
148+
149+
return "\n".join(lines)
150+
84151
@staticmethod
85152
def _get_role(msg: Any) -> str:
86153
if isinstance(msg, ChatMessage):

src/connectors/acp_core/types.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,27 @@ class AcpStreamPiece:
8282
reasoning_content: str | None = None
8383

8484

85+
@dataclass(frozen=True, slots=True)
86+
class HistoryState:
87+
"""Tracks how much of ``processed_messages`` has been applied to the ACP agent."""
88+
89+
message_count: int
90+
prefix_hash: str
91+
92+
8593
@dataclass(slots=True)
8694
class ACPProcessRuntime:
87-
"""Live ACP runtime bound to a project directory and model."""
95+
"""Live ACP runtime bound to a project directory, model, and client session."""
8896

8997
project_dir: Path
9098
model: str
99+
client_session_id: str = "default"
91100
process: Any | None = None
92101
session_id: str | None = None
93102
initialized: bool = False
94103
message_id: int = 0
95104
last_activity: float = 0.0
96-
history_injected: bool = False
105+
history_state: HistoryState | None = None
97106
process_lock: Any = field(default=None)
98107
request_lock: Any = field(default=None)
99108
cancellation_lock: Any = field(default=None)

0 commit comments

Comments
 (0)