1616from pathlib import Path
1717from typing import Callable , Iterator , Optional , Sequence , TYPE_CHECKING
1818
19- from ..exceptions import ForkError , ProcessBranchError
19+ from ..exceptions import ForkError , MemoryProtectError , ProcessBranchError
20+ from ..memory ._mprotect import (
21+ MemoryRegion ,
22+ PROT_READ ,
23+ PROT_WRITE ,
24+ mprotect as _mprotect ,
25+ )
2026from . import _cgroup
2127from ._namespace import setup_user_ns , bind_mount
2228
@@ -43,6 +49,7 @@ def __init__(
4349 close_fds : bool = False ,
4450 limits : ResourceLimits | None = None ,
4551 parent_cgroup : Path | None = None ,
52+ protected_regions : Sequence [tuple [int , int ]] | None = None ,
4653 ):
4754 """
4855 Args:
@@ -56,16 +63,21 @@ def __init__(
5663 parent_cgroup: Optional parent cgroup directory. When given,
5764 the child's scope is created under this directory instead
5865 of the process's own cgroup, enabling hierarchical nesting.
66+ protected_regions: Optional list of (addr, size) tuples.
67+ After fork, these regions are marked read-only in the
68+ parent via mprotect(2) to enforce the branch invariant.
5969 """
6070 self ._target = target
6171 self ._workspace = workspace
6272 self ._isolate = isolate
6373 self ._close_fds = close_fds
6474 self ._limits = limits
6575 self ._parent_cgroup = parent_cgroup
76+ self ._protected_regions = protected_regions
6677 self ._pid : Optional [int ] = None
6778 self ._exited = False
6879 self ._cgroup_scope : Optional [Path ] = None
80+ self ._memory_regions : list [MemoryRegion ] = []
6981
7082 @property
7183 def pid (self ) -> int :
@@ -91,6 +103,9 @@ def alive(self) -> bool:
91103 def wait (self , timeout : Optional [float ] = None ) -> None :
92104 """Wait for the child process to exit.
93105
106+ On completion (success or failure), restores write access to any
107+ protected memory regions so the parent can proceed.
108+
94109 Args:
95110 timeout: Maximum seconds to wait (None = wait forever).
96111
@@ -99,6 +114,7 @@ def wait(self, timeout: Optional[float] = None) -> None:
99114 TimeoutError: If the child doesn't exit within timeout.
100115 """
101116 exit_code = self ._wait_raw (timeout )
117+ self ._restore_memory_regions ()
102118 if exit_code != 0 :
103119 raise ProcessBranchError (
104120 f"Child { self ._pid } exited with status { exit_code } "
@@ -132,12 +148,15 @@ def _wait_raw(self, timeout: Optional[float] = None) -> int:
132148 def abort (self , timeout : float = 5.0 ) -> None :
133149 """Abort the child and all its descendants.
134150
135- Uses both cgroup kill (catches escapees) and killpg (POSIX standard).
136- Escalates SIGTERM -> SIGKILL after timeout.
151+ Restores write access to protected memory regions, then kills the
152+ child. Uses both cgroup kill (catches escapees) and killpg (POSIX
153+ standard). Escalates SIGTERM -> SIGKILL after timeout.
137154 """
138155 if self ._pid is None or self ._exited :
139156 return
140157
158+ self ._restore_memory_regions ()
159+
141160 # Cgroup kill — catches descendants that escaped the process group
142161 if self ._cgroup_scope is not None :
143162 _cgroup .kill_scope (self ._cgroup_scope )
@@ -172,6 +191,12 @@ def abort(self, timeout: float = 5.0) -> None:
172191 self ._reap ()
173192 self ._exited = True
174193
194+ def _restore_memory_regions (self ) -> None :
195+ """Restore write access to all protected memory regions."""
196+ for region in self ._memory_regions :
197+ _mprotect (region .addr , region .size , region .original_prot )
198+ self ._memory_regions .clear ()
199+
175200 def _reap (self ) -> None :
176201 """Reap the child process (non-blocking, best-effort)."""
177202 try :
@@ -243,11 +268,26 @@ def __enter__(self) -> "BranchContext":
243268 except OSError :
244269 pass # Best-effort
245270
271+ # Protect registered memory regions in the parent
272+ if self ._protected_regions :
273+ for addr , size in self ._protected_regions :
274+ region = MemoryRegion (
275+ addr = addr ,
276+ size = size ,
277+ original_prot = PROT_READ | PROT_WRITE ,
278+ )
279+ _mprotect (addr , size , PROT_READ )
280+ self ._memory_regions .append (region )
281+
246282 return self
247283
248284 def __exit__ (self , exc_type , exc_val , exc_tb ) -> bool :
249285 self .abort ()
250286
287+ # If abort() was a no-op (child already waited/exited), ensure
288+ # regions are still restored.
289+ self ._restore_memory_regions ()
290+
251291 # Clean up private mount dir (best-effort)
252292 try :
253293 os .rmdir (self ._private_dir )
@@ -266,6 +306,7 @@ def create(
266306 close_fds : bool = False ,
267307 limits : ResourceLimits | None = None ,
268308 parent_cgroup : Path | None = None ,
309+ protected_regions : Sequence [tuple [int , int ]] | None = None ,
269310 ) -> Iterator [list ["BranchContext" ]]:
270311 """Create N branch contexts, mirroring branch(BR_CREATE, n_branches=N).
271312
@@ -278,6 +319,8 @@ def create(
278319 close_fds: BR_CLOSE_FDS — close inherited fds in children.
279320 limits: Optional resource limits applied via cgroup v2.
280321 parent_cgroup: Optional parent cgroup for hierarchical nesting.
322+ protected_regions: Optional list of (addr, size) tuples to
323+ mark read-only in the parent after each fork.
281324
282325 Yields:
283326 List of entered BranchContext instances (already forked).
@@ -291,6 +334,7 @@ def create(
291334 ctx = BranchContext (
292335 target , workspace , isolate = isolate , close_fds = close_fds ,
293336 limits = limits , parent_cgroup = parent_cgroup ,
337+ protected_regions = protected_regions ,
294338 )
295339 ctx .__enter__ ()
296340 contexts .append (ctx )
0 commit comments