Skip to content

Commit 4543ca6

Browse files
author
陈云亮
committed
feat(raw_cursor): 支持位置占位符和查询缓存机制
- 实现 GaussDBRawQuery 支持只使用位置占位符 ($1, $2, ...) ,禁止命名占位符 - 添加查询缓存功能,缓存已解析的查询字节串,提高执行效率 - 在执行时检查查询是否包含命名占位符,若有则抛出明确的 ProgrammingError - 参数序列化时严格要求参数为序列类型,若传入字典则抛出 TypeError,提示使用普通 Cursor - 提供 clear_cache 方法支持清理查询缓存 - 补充单元测试,验证命名参数使用异常抛出、查询缓存以及缓存清理功能正常工作
1 parent e071588 commit 4543ca6

2 files changed

Lines changed: 99 additions & 2 deletions

File tree

gaussdb/gaussdb/raw_cursor.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from typing import TYPE_CHECKING
1010

11+
from . import errors as e
1112
from .abc import ConnectionType, Params, Query
1213
from .sql import Composable
1314
from .rows import Row
@@ -26,6 +27,19 @@
2627

2728

2829
class GaussDBRawQuery(GaussDBQuery):
30+
"""
31+
GaussDB raw query class.
32+
33+
Only supports positional placeholders ($1, $2, ...), not named placeholders.
34+
"""
35+
36+
# Query cache size
37+
_CACHE_SIZE = 128
38+
39+
def __init__(self, *args, **kwargs):
40+
super().__init__(*args, **kwargs)
41+
self._query_cache: dict[bytes, bytes] = {}
42+
2943
def convert(self, query: Query, vars: Params | None) -> None:
3044
if isinstance(query, str):
3145
bquery = query.encode(self._encoding)
@@ -34,14 +48,43 @@ def convert(self, query: Query, vars: Params | None) -> None:
3448
else:
3549
bquery = query
3650

37-
self.query = bquery
51+
# Try to get from cache
52+
if bquery in self._query_cache:
53+
self.query = self._query_cache[bquery]
54+
else:
55+
# Validate query doesn't contain named placeholders
56+
if b"%(" in bquery:
57+
raise e.ProgrammingError(
58+
"RawCursor does not support named placeholders (%(name)s). "
59+
"Use positional placeholders ($1, $2, ...) instead."
60+
)
61+
62+
self.query = bquery
63+
64+
# Cache result
65+
if len(self._query_cache) < self._CACHE_SIZE:
66+
self._query_cache[bquery] = bquery
67+
3868
self._want_formats = self._order = None
3969
self.dump(vars)
4070

4171
def dump(self, vars: Params | None) -> None:
72+
"""
73+
Serialize parameters.
74+
75+
Args:
76+
vars: Parameter sequence (must be sequence, not dict)
77+
78+
Raises:
79+
TypeError: If parameters are not a sequence
80+
"""
4281
if vars is not None:
4382
if not GaussDBQuery.is_params_sequence(vars):
44-
raise TypeError("raw queries require a sequence of parameters")
83+
raise TypeError(
84+
"RawCursor requires a sequence of parameters (tuple or list), "
85+
f"got {type(vars).__name__}. "
86+
"For named parameters, use regular Cursor instead."
87+
)
4588
self._want_formats = [PyFormat.AUTO] * len(vars)
4689

4790
self.params = self._tx.dump_sequence(vars, self._want_formats)
@@ -52,6 +95,10 @@ def dump(self, vars: Params | None) -> None:
5295
self.types = ()
5396
self.formats = None
5497

98+
def clear_cache(self) -> None:
99+
"""Clear query cache."""
100+
self._query_cache.clear()
101+
55102

56103
class RawCursorMixin(BaseCursor[ConnectionType, Row]):
57104
_query_cls = GaussDBRawQuery

tests/test_cursor_raw.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,28 @@ def test_sequence_only(conn):
3636
cur.execute("select 1", {})
3737

3838

39+
def test_named_params_error(conn):
40+
"""Test named parameter error message."""
41+
cur = conn.cursor()
42+
43+
# Should clearly indicate named placeholders are not supported
44+
with pytest.raises((TypeError, e.ProgrammingError)) as excinfo:
45+
cur.execute("select %(name)s", {"name": 1})
46+
47+
error_msg = str(excinfo.value).lower()
48+
assert "named" in error_msg or "sequence" in error_msg
49+
50+
51+
def test_dict_params_error(conn):
52+
"""Test dict parameter error."""
53+
cur = conn.cursor()
54+
55+
with pytest.raises(TypeError) as excinfo:
56+
cur.execute("select $1", {"a": 1})
57+
58+
assert "sequence" in str(excinfo.value).lower()
59+
60+
3961
def test_execute_many_results_param(conn):
4062
cur = conn.cursor()
4163
# GaussDB raises SyntaxError, CRDB raises InvalidPreparedStatementDefinition
@@ -115,3 +137,31 @@ def work():
115137
gc.collect()
116138
n.append(gc.count())
117139
assert n[0] == n[1] == n[2], f"objects leaked: {n[1] - n[0]}, {n[2] - n[1]}"
140+
141+
142+
def test_query_cache(conn):
143+
"""Test query cache."""
144+
cur = conn.cursor()
145+
query = "select $1::int"
146+
147+
# Execute same query multiple times
148+
for i in range(10):
149+
cur.execute(query, (i,))
150+
assert cur.fetchone()[0] == i
151+
152+
# Verify cache is working (no exception is good)
153+
154+
155+
def test_clear_cache(conn):
156+
"""Test clearing cache."""
157+
cur = conn.cursor()
158+
query = "select $1::int"
159+
160+
cur.execute(query, (1,))
161+
162+
# Clearing cache should not affect subsequent queries
163+
if hasattr(cur._query, "clear_cache"):
164+
cur._query.clear_cache()
165+
166+
cur.execute(query, (2,))
167+
assert cur.fetchone()[0] == 2

0 commit comments

Comments
 (0)