Skip to content

Commit 1858206

Browse files
authored
Merge pull request #1 from not-empty/feature/better-connections
Feature/better connections
2 parents 04defd1 + 416ddb4 commit 1858206

6 files changed

Lines changed: 162 additions & 57 deletions

File tree

src/omniq/_ops.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import redis
33

44
from dataclasses import dataclass
5-
from typing import Optional, Any, List
5+
from typing import Optional, Any, List, ClassVar
66
from threading import Lock
77

88
from .clock import now_ms
@@ -14,7 +14,7 @@
1414

1515
@dataclass
1616
class OmniqOps:
17-
_script_lock = Lock()
17+
_script_lock: ClassVar[Lock] = Lock()
1818
r: RedisLike
1919
scripts: OmniqScripts
2020

@@ -29,7 +29,11 @@ def _evalsha_with_noscript_fallback(
2929
return self.r.evalsha(sha, numkeys, *keys_and_args)
3030
except redis.exceptions.NoScriptError:
3131
with self._script_lock:
32-
return self.r.eval(src, numkeys, *keys_and_args)
32+
try:
33+
return self.r.evalsha(sha, numkeys, *keys_and_args)
34+
except redis.exceptions.NoScriptError:
35+
new_sha = self.r.script_load(src)
36+
return self.r.evalsha(new_sha, numkeys, *keys_and_args)
3337

3438
def publish(
3539
self,
@@ -500,7 +504,6 @@ def child_ack(self, *, key: str, child_id: str) -> int:
500504
except Exception:
501505
return -1
502506

503-
504507
@staticmethod
505508
def paused_backoff_s(poll_interval_s: float) -> float:
506509
return max(0.25, float(poll_interval_s) * 10.0)

src/omniq/client.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,19 @@
77
from .types import ReserveResult, AckFailResult
88
from .helper import queue_base
99

10+
def _safe_close_redis(r: Any) -> None:
11+
if r is None:
12+
return
13+
try:
14+
r.close()
15+
return
16+
except Exception:
17+
pass
18+
try:
19+
r.connection_pool.disconnect()
20+
except Exception:
21+
pass
22+
1023
@dataclass
1124
class OmniqClient:
1225
_ops: OmniqOps
@@ -23,7 +36,10 @@ def __init__(
2336
password: Optional[str] = None,
2437
ssl: bool = False,
2538
scripts_dir: Optional[str] = None,
39+
client_name: Optional[str] = None,
2640
):
41+
self._owns_redis = redis is None
42+
2743
if redis is not None:
2844
r = redis
2945
else:
@@ -38,13 +54,31 @@ def __init__(
3854
ssl=ssl,
3955
)
4056
)
57+
58+
if client_name:
59+
try:
60+
r.client_setname(str(client_name))
61+
except Exception:
62+
pass
4163

4264
if scripts_dir is None:
4365
scripts_dir = default_scripts_dir()
4466
scripts = load_scripts(r, scripts_dir)
4567

4668
self._ops = OmniqOps(r=r, scripts=scripts)
4769

70+
def close(self) -> None:
71+
if not getattr(self, "_owns_redis", False):
72+
return
73+
r = getattr(self._ops, "r", None)
74+
_safe_close_redis(r)
75+
76+
def __enter__(self) -> "OmniqClient":
77+
return self
78+
79+
def __exit__(self, exc_type, exc, tb) -> None:
80+
self.close()
81+
4882
@staticmethod
4983
def queue_base(queue_name: str) -> str:
5084
return queue_base(queue_name)

src/omniq/consumer.py

Lines changed: 62 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,37 @@ def start_heartbeater(
3131
stop_evt = threading.Event()
3232
flags: Dict[str, bool] = {"lost": False}
3333

34+
def _lost(msg: str) -> bool:
35+
msg_u = (msg or "").upper()
36+
return ("NOT_ACTIVE" in msg_u) or ("TOKEN_MISMATCH" in msg_u)
37+
3438
def hb_loop():
3539
try:
3640
client.heartbeat(queue=queue, job_id=job_id, lease_token=lease_token)
3741
except Exception as e:
42+
if stop_evt.is_set():
43+
return
3844
msg = str(e)
39-
if "NOT_ACTIVE" in msg or "TOKEN_MISMATCH" in msg:
45+
if _lost(msg):
4046
flags["lost"] = True
4147
stop_evt.set()
4248
return
49+
time.sleep(min(0.2, max(0.01, float(interval_s))))
4350

44-
while not stop_evt.wait(interval_s):
51+
while True:
52+
if stop_evt.wait(interval_s):
53+
return
4554
try:
4655
client.heartbeat(queue=queue, job_id=job_id, lease_token=lease_token)
4756
except Exception as e:
57+
if stop_evt.is_set():
58+
return
4859
msg = str(e)
49-
if "NOT_ACTIVE" in msg or "TOKEN_MISMATCH" in msg:
60+
if _lost(msg):
5061
flags["lost"] = True
5162
stop_evt.set()
5263
return
64+
time.sleep(min(0.2, max(0.01, float(interval_s))))
5365

5466
t = threading.Thread(target=hb_loop, daemon=True)
5567
t.start()
@@ -93,37 +105,35 @@ def consume(
93105

94106
ctrl = StopController(stop=False, sigint_count=0)
95107

96-
if stop_on_ctrl_c and threading.current_thread() is threading.main_thread():
97-
def on_sigterm(signum, _frame):
98-
ctrl.stop = True
99-
if verbose:
100-
_safe_log(logger, f"[consume] SIGTERM received; stopping... queue={queue}")
108+
prev_sigterm = None
109+
prev_sigint = None
101110

102-
signal.signal(signal.SIGTERM, on_sigterm)
111+
try:
112+
if stop_on_ctrl_c and threading.current_thread() is threading.main_thread():
113+
def on_sigterm(signum, _frame):
114+
ctrl.stop = True
115+
if verbose:
116+
_safe_log(logger, f"[consume] SIGTERM received; stopping... queue={queue}")
103117

104-
if drain:
105-
prev = signal.getsignal(signal.SIGINT)
118+
prev_sigterm = signal.getsignal(signal.SIGTERM)
119+
signal.signal(signal.SIGTERM, on_sigterm)
106120

107-
def on_sigint(signum, frame):
108-
ctrl.sigint_count += 1
109-
if ctrl.sigint_count >= 2:
110-
if verbose:
111-
_safe_log(logger, f"[consume] SIGINT x2; hard exit now. queue={queue}")
112-
try:
113-
signal.signal(signal.SIGINT, prev if prev else signal.SIG_DFL)
114-
except Exception:
115-
signal.signal(signal.SIGINT, signal.SIG_DFL)
116-
raise KeyboardInterrupt
121+
if drain:
122+
prev_sigint = signal.getsignal(signal.SIGINT)
117123

118-
ctrl.stop = True
119-
if verbose:
120-
_safe_log(logger, f"[consume] Ctrl+C received; draining current job then exiting. queue={queue}")
124+
def on_sigint(signum, frame):
125+
ctrl.sigint_count += 1
126+
if ctrl.sigint_count >= 2:
127+
if verbose:
128+
_safe_log(logger, f"[consume] SIGINT x2; hard exit now. queue={queue}")
129+
raise KeyboardInterrupt
121130

122-
signal.signal(signal.SIGINT, on_sigint)
123-
else:
124-
pass
131+
ctrl.stop = True
132+
if verbose:
133+
_safe_log(logger, f"[consume] Ctrl+C received; draining current job then exiting. queue={queue}")
134+
135+
signal.signal(signal.SIGINT, on_sigint)
125136

126-
try:
127137
while True:
128138
if ctrl.stop:
129139
if verbose:
@@ -214,8 +224,6 @@ def on_sigint(signum, frame):
214224
try:
215225
handler(ctx)
216226

217-
hb.stop_evt.set()
218-
219227
if not hb.flags.get("lost", False):
220228
try:
221229
client.ack_success(queue=queue, job_id=res.job_id, lease_token=res.lease_token)
@@ -228,12 +236,9 @@ def on_sigint(signum, frame):
228236
except KeyboardInterrupt:
229237
if verbose:
230238
_safe_log(logger, f"[consume] KeyboardInterrupt; exiting now. queue={queue}")
231-
hb.stop_evt.set()
232239
return
233240

234241
except Exception as e:
235-
hb.stop_evt.set()
236-
237242
if not hb.flags.get("lost", False):
238243
try:
239244
err = f"{type(e).__name__}: {e}"
@@ -255,7 +260,12 @@ def on_sigint(signum, frame):
255260

256261
finally:
257262
try:
258-
hb.thread.join(timeout=0.1)
263+
hb.stop_evt.set()
264+
except Exception:
265+
pass
266+
try:
267+
join_timeout = max(0.2, min(2.0, hb_s * 1.5))
268+
hb.thread.join(timeout=join_timeout)
259269
except Exception:
260270
pass
261271

@@ -268,3 +278,21 @@ def on_sigint(signum, frame):
268278
if verbose:
269279
_safe_log(logger, f"[consume] KeyboardInterrupt (outer); exiting now. queue={queue}")
270280
return
281+
282+
finally:
283+
if stop_on_ctrl_c and threading.current_thread() is threading.main_thread():
284+
try:
285+
if prev_sigterm is not None:
286+
signal.signal(signal.SIGTERM, prev_sigterm)
287+
except Exception:
288+
pass
289+
try:
290+
if prev_sigint is not None:
291+
signal.signal(signal.SIGINT, prev_sigint)
292+
except Exception:
293+
pass
294+
295+
try:
296+
client.close()
297+
except Exception:
298+
pass

src/omniq/exec.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from .client import OmniqClient
55

6-
76
@dataclass(frozen=True)
87
class Exec:
98
client: OmniqClient

src/omniq/scripts.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
from dataclasses import dataclass
3+
from threading import Lock
34
from typing import Protocol
45

56
class ScriptLoader(Protocol):
@@ -32,15 +33,23 @@ def default_scripts_dir() -> str:
3233
here = os.path.dirname(__file__)
3334
return os.path.join(here, "core", "scripts")
3435

36+
_scripts_cache: dict[str, OmniqScripts] = {}
37+
_scripts_cache_lock = Lock()
38+
3539
def load_scripts(r: ScriptLoader, scripts_dir: str) -> OmniqScripts:
40+
with _scripts_cache_lock:
41+
cached = _scripts_cache.get(scripts_dir)
42+
if cached is not None:
43+
return cached
44+
3645
def load_one(name: str) -> ScriptDef:
3746
path = os.path.join(scripts_dir, name)
3847
with open(path, "r", encoding="utf-8") as f:
3948
src = f.read()
4049
sha = r.script_load(src)
4150
return ScriptDef(sha=sha, src=src)
4251

43-
return OmniqScripts(
52+
scripts = OmniqScripts(
4453
enqueue=load_one("enqueue.lua"),
4554
reserve=load_one("reserve.lua"),
4655
ack_success=load_one("ack_success.lua"),
@@ -57,3 +66,8 @@ def load_one(name: str) -> ScriptDef:
5766
childs_init=load_one("childs_init.lua"),
5867
child_ack=load_one("child_ack.lua"),
5968
)
69+
70+
with _scripts_cache_lock:
71+
_scripts_cache[scripts_dir] = scripts
72+
73+
return scripts

0 commit comments

Comments
 (0)