Skip to content

Commit b99283e

Browse files
Merge pull request #236 from ezmsg-org/codex/hotpath-ab-suite
Performance: Quicker/Better Hotpath and A/B Testing
2 parents b283f79 + 1b84fed commit b99283e

10 files changed

Lines changed: 1294 additions & 24 deletions

File tree

scripts/perf_ab.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
# Thin script entry point: calls ``main`` (imported from ``ezmsg.util.perf.ab``)
# only when this file is executed directly, not when it is imported.
from ezmsg.util.perf.ab import main
2+
3+
4+
if __name__ == "__main__":
5+
main()

src/ezmsg/core/backendprocess.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -299,6 +299,7 @@ async def setup_state():
299299
buf_size=stream.buf_size,
300300
start_paused=True,
301301
force_tcp=stream.force_tcp,
302+
allow_local=stream.allow_local,
302303
),
303304
loop=loop,
304305
).result()

src/ezmsg/core/pubclient.py

Lines changed: 76 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,36 @@
3737

3838
BACKPRESSURE_WARNING = "EZMSG_DISABLE_BACKPRESSURE_WARNING" not in os.environ
3939
BACKPRESSURE_REFRACTORY = 5.0 # sec
40+
ALLOW_LOCAL_ENV = "EZMSG_ALLOW_LOCAL"
FORCE_TCP_ENV = "EZMSG_FORCE_TCP"

# Spellings (compared case-insensitively) that switch an environment flag on.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_flag(name: str, default: bool) -> bool:
    """
    Read a boolean flag from the process environment.

    Leading/trailing whitespace is ignored so that values such as ``" true "``
    are honoured. An unset, empty, or whitespace-only variable yields
    ``default``; any other value is True only if it is one of
    ``1``/``true``/``yes``/``on`` (case-insensitive).

    :param name: Name of the environment variable to read.
    :type name: str
    :param default: Value returned when the variable is unset or blank.
    :type default: bool
    :return: The parsed flag.
    :rtype: bool
    """
    value = os.environ.get(name, "").strip()
    if not value:
        return default
    return value.lower() in _TRUTHY_ENV_VALUES


def _process_allow_local_default() -> bool:
    """
    Process-wide default for local delivery, read from ``EZMSG_ALLOW_LOCAL``.

    :return: True when the variable is unset/blank or set to a truthy spelling.
    :rtype: bool
    """
    return _env_flag(ALLOW_LOCAL_ENV, True)


def _process_force_tcp_default() -> bool:
    """
    Process-wide default for forcing TCP, read from ``EZMSG_FORCE_TCP``.

    :return: False when the variable is unset/blank; True for a truthy spelling.
    :rtype: bool
    """
    return _env_flag(FORCE_TCP_ENV, False)


def _resolve_force_tcp(force_tcp: bool | None) -> bool:
    """
    Resolve an optional ``force_tcp`` setting to a concrete boolean.

    :param force_tcp: Explicit setting, or None to use the process default.
    :type force_tcp: bool | None
    :return: ``force_tcp`` itself when given, else the ``EZMSG_FORCE_TCP`` default.
    :rtype: bool
    """
    if force_tcp is None:
        return _process_force_tcp_default()
    return force_tcp


def _resolve_allow_local(force_tcp: bool, allow_local: bool | None) -> bool:
    """
    Resolve an optional ``allow_local`` setting, honouring ``force_tcp``.

    :param force_tcp: Already-resolved force-TCP setting for the publisher.
    :type force_tcp: bool
    :param allow_local: Explicit setting, or None to use the process default
        from ``EZMSG_ALLOW_LOCAL``.
    :type allow_local: bool | None
    :return: False whenever ``force_tcp`` is set; otherwise the resolved value.
    :rtype: bool
    """
    resolved = _process_allow_local_default() if allow_local is None else allow_local
    if force_tcp and resolved:
        # force_tcp wins over local delivery; log it so the override is visible.
        logger.info("force_tcp=True disables local delivery for this publisher")
        return False
    return resolved
4070

4171

4272
# Publisher needs a bit more information about connected channels
@@ -75,6 +105,7 @@ class Publisher:
75105
_msg_id: int
76106
_shm: SHMContext
77107
_force_tcp: bool
108+
_allow_local: bool
78109
_last_backpressure_event: float
79110

80111
_graph_address: AddressType | None
@@ -99,7 +130,8 @@ async def create(
99130
buf_size: int = DEFAULT_SHM_SIZE,
100131
num_buffers: int = 32,
101132
start_paused: bool = False,
102-
force_tcp: bool = False,
133+
force_tcp: bool | None = None,
134+
allow_local: bool | None = None,
103135
) -> "Publisher":
104136
"""
105137
Create a new Publisher instance and register it with the graph server.
@@ -116,6 +148,16 @@ async def create(
116148
:type port: int | None
117149
:param buf_size: Size of shared memory buffers.
118150
:type buf_size: int
151+
:param force_tcp: Whether to force TCP transport instead of shared memory.
152+
If None, inherit the process default from ``EZMSG_FORCE_TCP`` which
153+
defaults to disabled.
154+
:type force_tcp: bool | None
155+
:param allow_local: Whether to allow the in-process fast path when available.
156+
If None, inherit the process default from ``EZMSG_ALLOW_LOCAL`` which
157+
defaults to enabled. Set to False to bypass local delivery and
158+
characterize same-process SHM or TCP. When ``force_tcp=True``, local
159+
delivery is disabled regardless of this value.
160+
:type allow_local: bool | None
119161
:param kwargs: Additional keyword arguments for Publisher constructor.
120162
:return: Initialized and registered Publisher instance.
121163
:rtype: Publisher
@@ -127,6 +169,8 @@ async def create(
127169
writer.write(Command.PUBLISH.value)
128170
writer.write(encode_str(topic))
129171

172+
resolved_force_tcp = _resolve_force_tcp(force_tcp)
173+
130174
pub_id = UUID(await read_str(reader))
131175
pub = cls(
132176
id=pub_id,
@@ -135,7 +179,8 @@ async def create(
135179
graph_address=graph_address,
136180
num_buffers=num_buffers,
137181
start_paused=start_paused,
138-
force_tcp=force_tcp,
182+
force_tcp=resolved_force_tcp,
183+
allow_local=allow_local,
139184
_guard=cls._SENTINEL,
140185
)
141186

@@ -189,7 +234,8 @@ def __init__(
189234
graph_address: AddressType | None = None,
190235
num_buffers: int = 32,
191236
start_paused: bool = False,
192-
force_tcp: bool = False,
237+
force_tcp: bool | None = None,
238+
allow_local: bool | None = None,
193239
_guard = None
194240
) -> None:
195241
"""
@@ -207,7 +253,12 @@ def __init__(
207253
:param start_paused: Whether to start in paused state.
208254
:type start_paused: bool
209255
:param force_tcp: Whether to force TCP transport instead of shared memory.
210-
:type force_tcp: bool
256+
If None, inherit the process default from ``EZMSG_FORCE_TCP``.
257+
:type force_tcp: bool | None
258+
:param allow_local: Whether to allow the direct in-process fast path when available.
259+
If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``.
260+
When ``force_tcp=True``, local delivery is disabled regardless of this value.
261+
:type allow_local: bool | None
211262
"""
212263
if _guard is not self._SENTINEL:
213264
raise TypeError(
@@ -227,7 +278,8 @@ def __init__(
227278
self._running.set()
228279
self._num_buffers = num_buffers
229280
self._backpressure = Backpressure(num_buffers)
230-
self._force_tcp = force_tcp
281+
self._force_tcp = _resolve_force_tcp(force_tcp)
282+
self._allow_local = _resolve_allow_local(self._force_tcp, allow_local)
231283
self._last_backpressure_event = -1
232284
self._graph_address = graph_address
233285

@@ -436,22 +488,18 @@ async def broadcast(self, obj: Any) -> None:
436488
self._last_backpressure_event = time.time()
437489
await self._backpressure.wait(buf_idx)
438490

439-
# Get local channel and put variable there for local tx
440-
self._local_channel.put_local(self._msg_id, obj)
491+
if self._should_use_local_fast_path():
492+
self._local_channel.put_local(self._msg_id, obj)
441493

442-
if self._force_tcp or any(
443-
ch.pid != self.pid or not ch.shm_ok for ch in self._channels.values()
444-
):
494+
if any(not self._can_deliver_locally(ch) for ch in self._channels.values()):
445495
with MessageMarshal.serialize(self._msg_id, obj) as (
446496
total_size,
447497
header,
448498
buffers,
449499
):
450500
total_size_bytes = uint64_to_bytes(total_size)
451501

452-
if not self._force_tcp and any(
453-
ch.pid != self.pid and ch.shm_ok for ch in self._channels.values()
454-
):
502+
if any(self._can_deliver_via_shm(ch) for ch in self._channels.values()):
455503
if self._shm.buf_size < total_size:
456504
new_shm = await GraphService(self._graph_address).create_shm(
457505
self._num_buffers, total_size * 2
@@ -475,14 +523,10 @@ async def broadcast(self, obj: Any) -> None:
475523
for channel in self._channels.values():
476524
msg: bytes = b""
477525

478-
if self.pid == channel.pid and channel.shm_ok:
526+
if self._can_deliver_locally(channel):
479527
continue # Local transmission handled by channel.put
480528

481-
elif (
482-
(not self._force_tcp)
483-
and self.pid != channel.pid
484-
and channel.shm_ok
485-
):
529+
elif self._can_deliver_via_shm(channel):
486530
msg = (
487531
Command.TX_SHM.value
488532
+ msg_id_bytes
@@ -509,3 +553,16 @@ async def broadcast(self, obj: Any) -> None:
509553
)
510554

511555
self._msg_id += 1
556+
557+
def _should_use_local_fast_path(self) -> bool:
558+
return any(self._can_deliver_locally(ch) for ch in self._channels.values())
559+
560+
def _can_deliver_locally(self, channel: PubChannelInfo) -> bool:
    """
    Decide whether ``channel`` may be served by local delivery.

    All three conditions must hold: local delivery is allowed for this
    publisher, the channel's pid equals this publisher's pid, and the
    channel reports ``shm_ok``.

    :param channel: Channel record to evaluate.
    :type channel: PubChannelInfo
    :return: Truthy only when every condition above is satisfied.
    :rtype: bool
    """
    allowed = self._allow_local
    if not allowed:
        return allowed
    same_process = self.pid == channel.pid
    if not same_process:
        return same_process
    return channel.shm_ok
562+
563+
def _can_deliver_via_shm(self, channel: PubChannelInfo) -> bool:
564+
return (
565+
(not self._force_tcp)
566+
and channel.shm_ok
567+
and not self._can_deliver_locally(channel)
568+
)

src/ezmsg/core/stream.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -228,15 +228,20 @@ class OutputStream(Stream):
228228
:type num_buffers: int
229229
:param buf_size: Size of each message buffer in bytes
230230
:type buf_size: int
231-
:param force_tcp: Whether to force TCP transport instead of shared memory
232-
:type force_tcp: bool
231+
:param force_tcp: Whether to force TCP transport instead of shared memory.
232+
If None, inherit the process default from ``EZMSG_FORCE_TCP``.
233+
:type force_tcp: bool | None
234+
:param allow_local: Whether to allow the in-process fast path when available.
235+
If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``.
236+
:type allow_local: bool | None
233237
"""
234238

235239
host: str | None
236240
port: int | None
237241
num_buffers: int
238242
buf_size: int
239-
force_tcp: bool
243+
force_tcp: bool | None
244+
allow_local: bool | None
240245

241246
def __init__(
242247
self,
@@ -245,15 +250,20 @@ def __init__(
245250
port: int | None = None,
246251
num_buffers: int = 32,
247252
buf_size: int = DEFAULT_SHM_SIZE,
248-
force_tcp: bool = False,
253+
force_tcp: bool | None = None,
254+
allow_local: bool | None = None,
249255
) -> None:
250256
super().__init__(msg_type)
251257
self.host = host
252258
self.port = port
253259
self.num_buffers = num_buffers
254260
self.buf_size = buf_size
255261
self.force_tcp = force_tcp
262+
self.allow_local = allow_local
256263

257264
def __repr__(self) -> str:
258265
preamble = f"Output{super().__repr__()}"
259-
return f"{preamble}({self.num_buffers=}, {self.force_tcp=})"
266+
return (
267+
f"{preamble}({self.num_buffers=}, {self.force_tcp=}, "
268+
f"{self.allow_local=})"
269+
)

0 commit comments

Comments
 (0)