Skip to content

Commit f3a8c3d

Browse files
committed
Add mprotect-based parent memory protection after branching
Signed-off-by: Cong Wang <cwang@multikernel.io>
1 parent 5913eae commit f3a8c3d

5 files changed

Lines changed: 213 additions & 9 deletions

File tree

src/branching/exceptions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,9 @@ class MemoryBranchError(BranchingError):
7272
"""Memory branching operation failed."""
7373

7474
pass
75+
76+
77+
class MemoryProtectError(MemoryBranchError):
78+
"""mprotect() failed on a registered memory region."""
79+
80+
pass

src/branching/memory/__init__.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,25 @@
1-
from .base import MemoryBackend, StubMemoryBackend
1+
from .base import MemoryBackend, StubMemoryBackend, MprotectMemoryBackend
2+
from ._mprotect import (
3+
MemoryRegion,
4+
mprotect,
5+
protect_regions,
6+
restore_regions,
7+
PROT_NONE,
8+
PROT_READ,
9+
PROT_WRITE,
10+
PROT_EXEC,
11+
)
212

3-
__all__ = ["MemoryBackend", "StubMemoryBackend"]
13+
__all__ = [
14+
"MemoryBackend",
15+
"StubMemoryBackend",
16+
"MprotectMemoryBackend",
17+
"MemoryRegion",
18+
"mprotect",
19+
"protect_regions",
20+
"restore_regions",
21+
"PROT_NONE",
22+
"PROT_READ",
23+
"PROT_WRITE",
24+
"PROT_EXEC",
25+
]

src/branching/memory/_mprotect.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""ctypes bindings for mprotect(2).
3+
4+
Provides memory protection manipulation for enforcing read-only
5+
invariants on parent memory regions after branching.
6+
"""
7+
8+
import ctypes
9+
import ctypes.util
10+
import os
11+
from dataclasses import dataclass, field
12+
13+
from ..exceptions import MemoryProtectError
14+
15+
_libc = ctypes.CDLL(ctypes.util.find_library("c"), use_errno=True)
16+
17+
# Protection flags for mprotect(2)
18+
PROT_NONE = 0x0
19+
PROT_READ = 0x1
20+
PROT_WRITE = 0x2
21+
PROT_EXEC = 0x4
22+
23+
24+
def mprotect(addr: int, size: int, prot: int) -> None:
25+
"""Call mprotect(2) to change memory protection.
26+
27+
Args:
28+
addr: Page-aligned start address of the region.
29+
size: Size of the region in bytes.
30+
prot: New protection flags (PROT_READ, PROT_WRITE, etc.).
31+
32+
Raises:
33+
MemoryProtectError: If the syscall fails.
34+
"""
35+
ret = _libc.mprotect(ctypes.c_void_p(addr), ctypes.c_size_t(size), ctypes.c_int(prot))
36+
if ret != 0:
37+
err = ctypes.get_errno()
38+
raise MemoryProtectError(
39+
f"mprotect(0x{addr:x}, {size}, 0x{prot:x}): {os.strerror(err)}"
40+
)
41+
42+
43+
@dataclass
44+
class MemoryRegion:
45+
"""A memory region tracked for protection changes.
46+
47+
Attributes:
48+
addr: Page-aligned start address.
49+
size: Size of the region in bytes.
50+
original_prot: Protection flags before we modified them.
51+
"""
52+
addr: int
53+
size: int
54+
original_prot: int = field(default=PROT_READ | PROT_WRITE)
55+
56+
57+
def protect_regions(regions: list[MemoryRegion], prot: int = PROT_READ) -> None:
58+
"""Apply memory protection to a list of regions.
59+
60+
Each region's ``original_prot`` is assumed to already be set before
61+
calling this function.
62+
63+
Args:
64+
regions: Regions to protect.
65+
prot: Protection flags to apply.
66+
67+
Raises:
68+
MemoryProtectError: If any mprotect call fails.
69+
"""
70+
for region in regions:
71+
mprotect(region.addr, region.size, prot)
72+
73+
74+
def restore_regions(regions: list[MemoryRegion]) -> None:
75+
"""Restore each region to its original protection.
76+
77+
Args:
78+
regions: Regions whose original_prot should be restored.
79+
80+
Raises:
81+
MemoryProtectError: If any mprotect call fails.
82+
"""
83+
for region in regions:
84+
mprotect(region.addr, region.size, region.original_prot)

src/branching/memory/base.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
# SPDX-License-Identifier: Apache-2.0
2-
"""Memory branching stub for future kernel support.
2+
"""Memory branching backends.
33
4-
This module defines the MemoryBackend ABC and a stub implementation.
5-
Real memory branching requires kernel support via the proposed branch()
6-
syscall with BR_MEMORY flag.
4+
This module defines the MemoryBackend ABC, a stub implementation, and
5+
an mprotect-based implementation that enforces read-only parent memory
6+
after branching.
77
"""
88

99
from abc import ABC, abstractmethod
1010
from typing import Any
1111

12+
from ._mprotect import (
13+
MemoryRegion,
14+
PROT_READ,
15+
PROT_WRITE,
16+
mprotect,
17+
)
18+
1219

1320
class MemoryBackend(ABC):
1421
"""Abstract base class for memory branching.
@@ -71,3 +78,44 @@ def commit(self, handle: Any) -> None:
7178
raise NotImplementedError(
7279
"Memory branching requires kernel support via the branch() syscall."
7380
)
81+
82+
83+
class MprotectMemoryBackend(MemoryBackend):
84+
"""Memory backend using mprotect(2) to enforce read-only parent state.
85+
86+
After ``snapshot()``, the region is marked read-only so the parent
87+
cannot mutate branched state. ``restore()`` and ``commit()`` both
88+
re-enable write access (the parent regains control after branching).
89+
"""
90+
91+
def snapshot(self, addr: int, size: int) -> MemoryRegion:
92+
"""Mark a memory region as read-only and return a handle.
93+
94+
Args:
95+
addr: Page-aligned start address of the region.
96+
size: Size of the region in bytes.
97+
98+
Returns:
99+
A MemoryRegion handle that can be passed to restore/commit.
100+
"""
101+
region = MemoryRegion(addr=addr, size=size, original_prot=PROT_READ | PROT_WRITE)
102+
mprotect(addr, size, PROT_READ)
103+
return region
104+
105+
def restore(self, handle: Any) -> None:
106+
"""Restore write access to a previously snapshotted region.
107+
108+
Args:
109+
handle: MemoryRegion from snapshot().
110+
"""
111+
region: MemoryRegion = handle
112+
mprotect(region.addr, region.size, region.original_prot)
113+
114+
def commit(self, handle: Any) -> None:
115+
"""Commit the snapshot — parent regains write access.
116+
117+
Args:
118+
handle: MemoryRegion from snapshot().
119+
"""
120+
region: MemoryRegion = handle
121+
mprotect(region.addr, region.size, region.original_prot)

src/branching/process/context.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,13 @@
1616
from pathlib import Path
1717
from typing import Callable, Iterator, Optional, Sequence, TYPE_CHECKING
1818

19-
from ..exceptions import ForkError, ProcessBranchError
19+
from ..exceptions import ForkError, MemoryProtectError, ProcessBranchError
20+
from ..memory._mprotect import (
21+
MemoryRegion,
22+
PROT_READ,
23+
PROT_WRITE,
24+
mprotect as _mprotect,
25+
)
2026
from . import _cgroup
2127
from ._namespace import setup_user_ns, bind_mount
2228

@@ -43,6 +49,7 @@ def __init__(
4349
close_fds: bool = False,
4450
limits: ResourceLimits | None = None,
4551
parent_cgroup: Path | None = None,
52+
protected_regions: Sequence[tuple[int, int]] | None = None,
4653
):
4754
"""
4855
Args:
@@ -56,16 +63,21 @@ def __init__(
5663
parent_cgroup: Optional parent cgroup directory. When given,
5764
the child's scope is created under this directory instead
5865
of the process's own cgroup, enabling hierarchical nesting.
66+
protected_regions: Optional list of (addr, size) tuples.
67+
After fork, these regions are marked read-only in the
68+
parent via mprotect(2) to enforce the branch invariant.
5969
"""
6070
self._target = target
6171
self._workspace = workspace
6272
self._isolate = isolate
6373
self._close_fds = close_fds
6474
self._limits = limits
6575
self._parent_cgroup = parent_cgroup
76+
self._protected_regions = protected_regions
6677
self._pid: Optional[int] = None
6778
self._exited = False
6879
self._cgroup_scope: Optional[Path] = None
80+
self._memory_regions: list[MemoryRegion] = []
6981

7082
@property
7183
def pid(self) -> int:
@@ -91,6 +103,9 @@ def alive(self) -> bool:
91103
def wait(self, timeout: Optional[float] = None) -> None:
92104
"""Wait for the child process to exit.
93105
106+
On completion (success or failure), restores write access to any
107+
protected memory regions so the parent can proceed.
108+
94109
Args:
95110
timeout: Maximum seconds to wait (None = wait forever).
96111
@@ -99,6 +114,7 @@ def wait(self, timeout: Optional[float] = None) -> None:
99114
TimeoutError: If the child doesn't exit within timeout.
100115
"""
101116
exit_code = self._wait_raw(timeout)
117+
self._restore_memory_regions()
102118
if exit_code != 0:
103119
raise ProcessBranchError(
104120
f"Child {self._pid} exited with status {exit_code}"
@@ -132,12 +148,15 @@ def _wait_raw(self, timeout: Optional[float] = None) -> int:
132148
def abort(self, timeout: float = 5.0) -> None:
133149
"""Abort the child and all its descendants.
134150
135-
Uses both cgroup kill (catches escapees) and killpg (POSIX standard).
136-
Escalates SIGTERM -> SIGKILL after timeout.
151+
Restores write access to protected memory regions, then kills the
152+
child. Uses both cgroup kill (catches escapees) and killpg (POSIX
153+
standard). Escalates SIGTERM -> SIGKILL after timeout.
137154
"""
138155
if self._pid is None or self._exited:
139156
return
140157

158+
self._restore_memory_regions()
159+
141160
# Cgroup kill — catches descendants that escaped the process group
142161
if self._cgroup_scope is not None:
143162
_cgroup.kill_scope(self._cgroup_scope)
@@ -172,6 +191,12 @@ def abort(self, timeout: float = 5.0) -> None:
172191
self._reap()
173192
self._exited = True
174193

194+
def _restore_memory_regions(self) -> None:
195+
"""Restore write access to all protected memory regions."""
196+
for region in self._memory_regions:
197+
_mprotect(region.addr, region.size, region.original_prot)
198+
self._memory_regions.clear()
199+
175200
def _reap(self) -> None:
176201
"""Reap the child process (non-blocking, best-effort)."""
177202
try:
@@ -243,11 +268,26 @@ def __enter__(self) -> "BranchContext":
243268
except OSError:
244269
pass # Best-effort
245270

271+
# Protect registered memory regions in the parent
272+
if self._protected_regions:
273+
for addr, size in self._protected_regions:
274+
region = MemoryRegion(
275+
addr=addr,
276+
size=size,
277+
original_prot=PROT_READ | PROT_WRITE,
278+
)
279+
_mprotect(addr, size, PROT_READ)
280+
self._memory_regions.append(region)
281+
246282
return self
247283

248284
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
249285
self.abort()
250286

287+
# If abort() was a no-op (child already waited/exited), ensure
288+
# regions are still restored.
289+
self._restore_memory_regions()
290+
251291
# Clean up private mount dir (best-effort)
252292
try:
253293
os.rmdir(self._private_dir)
@@ -266,6 +306,7 @@ def create(
266306
close_fds: bool = False,
267307
limits: ResourceLimits | None = None,
268308
parent_cgroup: Path | None = None,
309+
protected_regions: Sequence[tuple[int, int]] | None = None,
269310
) -> Iterator[list["BranchContext"]]:
270311
"""Create N branch contexts, mirroring branch(BR_CREATE, n_branches=N).
271312
@@ -278,6 +319,8 @@ def create(
278319
close_fds: BR_CLOSE_FDS — close inherited fds in children.
279320
limits: Optional resource limits applied via cgroup v2.
280321
parent_cgroup: Optional parent cgroup for hierarchical nesting.
322+
protected_regions: Optional list of (addr, size) tuples to
323+
mark read-only in the parent after each fork.
281324
282325
Yields:
283326
List of entered BranchContext instances (already forked).
@@ -291,6 +334,7 @@ def create(
291334
ctx = BranchContext(
292335
target, workspace, isolate=isolate, close_fds=close_fds,
293336
limits=limits, parent_cgroup=parent_cgroup,
337+
protected_regions=protected_regions,
294338
)
295339
ctx.__enter__()
296340
contexts.append(ctx)

0 commit comments

Comments
 (0)