|
6 | 6 | @software: PyCharm |
7 | 7 | @time: 18-12-25 下午4:58 |
8 | 8 | """ |
9 | | - |
| 9 | +import threading |
10 | 10 | from collections import MutableMapping, MutableSequence |
11 | 11 | from contextlib import contextmanager |
12 | 12 | from typing import Dict, Generator, List, Union |
@@ -90,6 +90,7 @@ def __init__(self, app=None, *, username: str = "root", passwd: str = None, host |
90 | 90 | self.dialect: str = dialect |
91 | 91 | self.msg_zh: str = "" |
92 | 92 | self.scoped_sessions: Dict[str, scoped_session] = {} # 主要保存其他scope session |
| 93 | + self.registry = threading.local() # 当前线程注册bind key |
93 | 94 |
|
94 | 95 | # 这里要用重写的BaseQuery, 根据BaseQuery的规则,Model中的query_class也需要重新指定为子类model, |
95 | 96 | # 但是从Model的初始化看,如果Model的query_class为None的话还是会设置为和Query一致,符合要求 |
@@ -156,8 +157,8 @@ def init_app(self, app, username: str = None, passwd: str = None, host: str = No |
156 | 157 |
|
157 | 158 | @app.teardown_appcontext |
158 | 159 | def _shutdown_other_session(response_or_exc): |
159 | | - for _, session_ in self.scoped_sessions.items(): |
160 | | - session_.remove() |
| 160 | + for bind_key in getattr(self.registry, "bind_keys", set()): |
| 161 | + self.scoped_sessions[bind_key].remove() |
161 | 162 | return response_or_exc |
162 | 163 |
|
163 | 164 | def get_engine(self, app=None, bind=None): |
@@ -233,6 +234,10 @@ def gen_session(self, bind_key: str, session_options: Dict = None) -> Session: |
233 | 234 | session = self.scoped_sessions[bind_key]() |
234 | 235 | session.bind_key = bind_key # 设置bind key |
235 | 236 | session = self.ping_session(session) # 校验重连,保证可用 |
| 237 | + # 加入当前线程bindkey,用于自动关闭处理 |
| 238 | + if hasattr(self.registry, "bind_keys") is False: |
| 239 | + self.registry.bind_keys = set() |
| 240 | + self.registry.bind_keys.add(bind_key) |
236 | 241 |
|
237 | 242 | return session |
238 | 243 |
|
@@ -449,6 +454,7 @@ class CustomBaseQuery(BaseQuery): |
449 | 454 | 目前是改造如果limit传递为0,则返回所有的数据,这样业务代码中就不用更改了 |
450 | 455 | """ |
451 | 456 |
|
| 457 | + # noinspection DuplicatedCode |
452 | 458 | def paginate(self, page: int = 1, per_page: int = 20, max_per_page: int = None, |
453 | 459 | primary_order: bool = True) -> Pagination: |
454 | 460 | """Returns ``per_page`` items from page ``page``. |
|
0 commit comments