3737
3838BACKPRESSURE_WARNING = "EZMSG_DISABLE_BACKPRESSURE_WARNING" not in os .environ
3939BACKPRESSURE_REFRACTORY = 5.0 # sec
40+ ALLOW_LOCAL_ENV = "EZMSG_ALLOW_LOCAL"
41+ FORCE_TCP_ENV = "EZMSG_FORCE_TCP"
42+
43+
44+ def _process_allow_local_default () -> bool :
45+ value = os .environ .get (ALLOW_LOCAL_ENV , "" )
46+ if value == "" :
47+ return True
48+ return value .lower () in ("1" , "true" , "yes" , "on" )
49+
50+
51+ def _process_force_tcp_default () -> bool :
52+ value = os .environ .get (FORCE_TCP_ENV , "" )
53+ if value == "" :
54+ return False
55+ return value .lower () in ("1" , "true" , "yes" , "on" )
56+
57+
58+ def _resolve_force_tcp (force_tcp : bool | None ) -> bool :
59+ if force_tcp is None :
60+ return _process_force_tcp_default ()
61+ return force_tcp
62+
63+
64+ def _resolve_allow_local (force_tcp : bool , allow_local : bool | None ) -> bool :
65+ resolved = _process_allow_local_default () if allow_local is None else allow_local
66+ if force_tcp and resolved :
67+ logger .info ("force_tcp=True disables local delivery for this publisher" )
68+ return False
69+ return resolved
4070
4171
4272# Publisher needs a bit more information about connected channels
@@ -75,6 +105,7 @@ class Publisher:
75105 _msg_id : int
76106 _shm : SHMContext
77107 _force_tcp : bool
108+ _allow_local : bool
78109 _last_backpressure_event : float
79110
80111 _graph_address : AddressType | None
@@ -99,7 +130,8 @@ async def create(
99130 buf_size : int = DEFAULT_SHM_SIZE ,
100131 num_buffers : int = 32 ,
101132 start_paused : bool = False ,
102- force_tcp : bool = False ,
133+ force_tcp : bool | None = None ,
134+ allow_local : bool | None = None ,
103135 ) -> "Publisher" :
104136 """
105137 Create a new Publisher instance and register it with the graph server.
@@ -116,6 +148,16 @@ async def create(
116148 :type port: int | None
117149 :param buf_size: Size of shared memory buffers.
118150 :type buf_size: int
151+ :param force_tcp: Whether to force TCP transport instead of shared memory.
152+ If None, inherit the process default from ``EZMSG_FORCE_TCP`` which
153+ defaults to disabled.
154+ :type force_tcp: bool | None
155+ :param allow_local: Whether to allow the in-process fast path when available.
156+ If None, inherit the process default from ``EZMSG_ALLOW_LOCAL`` which
157+ defaults to enabled. Set to False to bypass local delivery and
158+ characterize same-process SHM or TCP. When ``force_tcp=True``, local
159+ delivery is disabled regardless of this value.
160+ :type allow_local: bool | None
119161 :param kwargs: Additional keyword arguments for Publisher constructor.
120162 :return: Initialized and registered Publisher instance.
121163 :rtype: Publisher
@@ -127,6 +169,8 @@ async def create(
127169 writer .write (Command .PUBLISH .value )
128170 writer .write (encode_str (topic ))
129171
172+ resolved_force_tcp = _resolve_force_tcp (force_tcp )
173+
130174 pub_id = UUID (await read_str (reader ))
131175 pub = cls (
132176 id = pub_id ,
@@ -135,7 +179,8 @@ async def create(
135179 graph_address = graph_address ,
136180 num_buffers = num_buffers ,
137181 start_paused = start_paused ,
138- force_tcp = force_tcp ,
182+ force_tcp = resolved_force_tcp ,
183+ allow_local = allow_local ,
139184 _guard = cls ._SENTINEL ,
140185 )
141186
@@ -189,7 +234,8 @@ def __init__(
189234 graph_address : AddressType | None = None ,
190235 num_buffers : int = 32 ,
191236 start_paused : bool = False ,
192- force_tcp : bool = False ,
237+ force_tcp : bool | None = None ,
238+ allow_local : bool | None = None ,
193239 _guard = None
194240 ) -> None :
195241 """
@@ -207,7 +253,12 @@ def __init__(
207253 :param start_paused: Whether to start in paused state.
208254 :type start_paused: bool
209255 :param force_tcp: Whether to force TCP transport instead of shared memory.
210- :type force_tcp: bool
256+ If None, inherit the process default from ``EZMSG_FORCE_TCP``.
257+ :type force_tcp: bool | None
258+ :param allow_local: Whether to allow the direct in-process fast path when available.
259+ If None, inherit the process default from ``EZMSG_ALLOW_LOCAL``.
260+ When ``force_tcp=True``, local delivery is disabled regardless of this value.
261+ :type allow_local: bool | None
211262 """
212263 if _guard is not self ._SENTINEL :
213264 raise TypeError (
@@ -227,7 +278,8 @@ def __init__(
227278 self ._running .set ()
228279 self ._num_buffers = num_buffers
229280 self ._backpressure = Backpressure (num_buffers )
230- self ._force_tcp = force_tcp
281+ self ._force_tcp = _resolve_force_tcp (force_tcp )
282+ self ._allow_local = _resolve_allow_local (self ._force_tcp , allow_local )
231283 self ._last_backpressure_event = - 1
232284 self ._graph_address = graph_address
233285
@@ -436,22 +488,18 @@ async def broadcast(self, obj: Any) -> None:
436488 self ._last_backpressure_event = time .time ()
437489 await self ._backpressure .wait (buf_idx )
438490
439- # Get local channel and put variable there for local tx
440- self ._local_channel .put_local (self ._msg_id , obj )
491+ if self . _should_use_local_fast_path ():
492+ self ._local_channel .put_local (self ._msg_id , obj )
441493
442- if self ._force_tcp or any (
443- ch .pid != self .pid or not ch .shm_ok for ch in self ._channels .values ()
444- ):
494+ if any (not self ._can_deliver_locally (ch ) for ch in self ._channels .values ()):
445495 with MessageMarshal .serialize (self ._msg_id , obj ) as (
446496 total_size ,
447497 header ,
448498 buffers ,
449499 ):
450500 total_size_bytes = uint64_to_bytes (total_size )
451501
452- if not self ._force_tcp and any (
453- ch .pid != self .pid and ch .shm_ok for ch in self ._channels .values ()
454- ):
502+ if any (self ._can_deliver_via_shm (ch ) for ch in self ._channels .values ()):
455503 if self ._shm .buf_size < total_size :
456504 new_shm = await GraphService (self ._graph_address ).create_shm (
457505 self ._num_buffers , total_size * 2
@@ -475,14 +523,10 @@ async def broadcast(self, obj: Any) -> None:
475523 for channel in self ._channels .values ():
476524 msg : bytes = b""
477525
478- if self .pid == channel . pid and channel . shm_ok :
526+ if self ._can_deliver_locally ( channel ) :
479527 continue # Local transmission handled by channel.put
480528
481- elif (
482- (not self ._force_tcp )
483- and self .pid != channel .pid
484- and channel .shm_ok
485- ):
529+ elif self ._can_deliver_via_shm (channel ):
486530 msg = (
487531 Command .TX_SHM .value
488532 + msg_id_bytes
@@ -509,3 +553,16 @@ async def broadcast(self, obj: Any) -> None:
509553 )
510554
511555 self ._msg_id += 1
556+
557+ def _should_use_local_fast_path (self ) -> bool :
558+ return any (self ._can_deliver_locally (ch ) for ch in self ._channels .values ())
559+
560+ def _can_deliver_locally (self , channel : PubChannelInfo ) -> bool :
561+ return self ._allow_local and self .pid == channel .pid and channel .shm_ok
562+
563+ def _can_deliver_via_shm (self , channel : PubChannelInfo ) -> bool :
564+ return (
565+ (not self ._force_tcp )
566+ and channel .shm_ok
567+ and not self ._can_deliver_locally (channel )
568+ )
0 commit comments