Skip to content

Commit 7226a10

Browse files
author
陈云亮
committed
feat(gaussdb): 支持 GaussDB OID 别名兼容及运行时查询
- 新增 GAUSSDB_OID_ALIASES 字典,实现 Postgres OID 到 GaussDB 多别名映射 - 添加 is_compatible_oid 方法,实现对 OID 兼容性的判断支持别名 - 实现 get_oid_name 方法,根据 OID 获取对应的类型名称字符串 - 在 TypeInfo 中增加 fetch_runtime_oid,支持运行时查询类型 OID,兼容同步和异步连接 - 在 TypeInfo 中添加 get_compatible_oids,用于获取基础 OID 及其别名列表 - 编写测试用例,覆盖 OID 兼容性判断、OID 名称获取及运行时 OID 查询 - 测试中增加 type_code 兼容性校验,确保 GaussDB 返回的 OID 合法有效
1 parent 2857bd8 commit 7226a10

4 files changed

Lines changed: 169 additions & 0 deletions

File tree

gaussdb/gaussdb/_oids.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,55 @@
122122
YEAR_OID = 1038
123123

124124
# autogenerated: end
125+
126+
127+
# =====================================================
128+
# GaussDB OID 别名映射(PostgreSQL OID -> GaussDB 等效 OID 列表)
129+
# =====================================================
130+
131+
GAUSSDB_OID_ALIASES: dict[int, list[int]] = {
132+
# date 类型可能映射到多个 OID
133+
DATE_OID: [DATE_OID, SMALLDATETIME_OID],
134+
# timestamp 类型
135+
TIMESTAMP_OID: [TIMESTAMP_OID, SMALLDATETIME_OID],
136+
TIMESTAMPTZ_OID: [TIMESTAMPTZ_OID],
137+
# 其他类型保持一对一映射
138+
}
139+
140+
141+
def is_compatible_oid(expected_oid: int, actual_oid: int) -> bool:
142+
"""
143+
检查两个 OID 是否兼容
144+
145+
用于 GaussDB 场景下的类型比较,考虑 OID 别名。
146+
147+
Args:
148+
expected_oid: 期望的 OID
149+
actual_oid: 实际的 OID
150+
151+
Returns:
152+
是否兼容
153+
"""
154+
if expected_oid == actual_oid:
155+
return True
156+
157+
# 检查别名映射
158+
aliases = GAUSSDB_OID_ALIASES.get(expected_oid, [expected_oid])
159+
return actual_oid in aliases
160+
161+
162+
def get_oid_name(oid: int) -> str:
163+
"""
164+
获取 OID 对应的类型名称
165+
166+
Args:
167+
oid: 类型 OID
168+
169+
Returns:
170+
类型名称字符串
171+
"""
172+
# 反向查找 OID 常量名
173+
for name, value in globals().items():
174+
if name.endswith("_OID") and value == oid:
175+
return name.replace("_OID", "").lower()
176+
return f"oid_{oid}"

gaussdb/gaussdb/_typeinfo.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from . import sql
1717
from .abc import AdaptContext, Query
1818
from .rows import dict_row
19+
from ._oids import GAUSSDB_OID_ALIASES
1920
from ._compat import TypeAlias, TypeVar
2021
from ._typemod import TypeModifier
2122
from ._encodings import conn_encoding
@@ -209,6 +210,56 @@ def get_precision(self, fmod: int) -> int | None:
209210
def get_scale(self, fmod: int) -> int | None:
210211
return self.typemod.get_scale(fmod)
211212

213+
@classmethod
214+
def fetch_runtime_oid(cls, conn: Any, typename: str) -> int | None:
215+
"""
216+
运行时获取类型 OID
217+
218+
从数据库查询正确的 OID,处理 GaussDB 差异。
219+
220+
Args:
221+
conn: 数据库连接
222+
typename: 类型名称
223+
224+
Returns:
225+
类型 OID,查询失败返回 None
226+
"""
227+
try:
228+
from .connection import Connection
229+
230+
if isinstance(conn, Connection):
231+
result = conn.execute(
232+
"SELECT oid FROM pg_type WHERE typname = %s", [typename]
233+
).fetchone()
234+
else:
235+
# AsyncConnection
236+
import asyncio
237+
238+
async def _fetch():
239+
result = await conn.execute(
240+
"SELECT oid FROM pg_type WHERE typname = %s", [typename]
241+
)
242+
return await result.fetchone()
243+
244+
result = asyncio.run(_fetch())
245+
246+
return result[0] if result else None
247+
except Exception:
248+
return None
249+
250+
@classmethod
251+
def get_compatible_oids(cls, base_oid: int) -> list[int]:
252+
"""
253+
获取兼容的 OID 列表
254+
255+
Args:
256+
base_oid: 基础 OID
257+
258+
Returns:
259+
包含基础 OID 及其别名的列表
260+
"""
261+
return GAUSSDB_OID_ALIASES.get(base_oid, [base_oid])
262+
212263

213264
class TypesRegistry:
214265
"""

tests/test_gaussdb_dbapi20.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,30 @@ def fake_connect(conninfo, *, timeout=0.0):
161161
def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype):
162162
with pytest.raises(exctype):
163163
gaussdb.connect(*args, **kwargs)
164+
165+
166+
class TestTypeCode:
167+
"""type_code 兼容性测试"""
168+
169+
def test_type_code_comparison(self, conn):
170+
"""测试 type_code 比较"""
171+
cur = conn.cursor()
172+
cur.execute("select 1::int, 'hello'::text")
173+
174+
desc = cur.description
175+
176+
# 验证 type_code 是整数
177+
for col in desc:
178+
assert isinstance(col.type_code, int)
179+
assert col.type_code > 0
180+
181+
def test_type_code_date(self, conn):
182+
"""测试日期类型 type_code"""
183+
cur = conn.cursor()
184+
cur.execute("select current_date")
185+
186+
type_code = cur.description[0].type_code
187+
188+
# GaussDB 可能返回不同的 OID
189+
# 只要是有效的 OID 即可
190+
assert type_code > 0, f"Invalid type_code: {type_code}"

tests/test_typeinfo.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,14 @@
33
import gaussdb
44
from gaussdb import sql
55
from gaussdb.pq import TransactionStatus
6+
from gaussdb._oids import (
7+
DATE_OID,
8+
GAUSSDB_OID_ALIASES,
9+
SMALLDATETIME_OID,
10+
TIMESTAMP_OID,
11+
get_oid_name,
12+
is_compatible_oid,
13+
)
614
from gaussdb.types import TypeInfo
715
from gaussdb.types.enum import EnumInfo
816
from gaussdb.types.range import RangeInfo
@@ -208,3 +216,34 @@ def test_registry_isolated():
208216
print(f"orig={orig},tinfo={tinfo},r={r},tdummy={tdummy}")
209217
assert r[25] is r["dummy"] is tdummy
210218
assert orig[25] is r["text"] is tinfo
219+
220+
221+
class TestOidCompatibility:
222+
"""OID 兼容性测试"""
223+
224+
def test_same_oid_compatible(self):
225+
"""相同 OID 应兼容"""
226+
assert is_compatible_oid(DATE_OID, DATE_OID)
227+
assert is_compatible_oid(TIMESTAMP_OID, TIMESTAMP_OID)
228+
229+
def test_alias_oid_compatible(self):
230+
"""别名 OID 应兼容"""
231+
# 如果 smalldatetime 是 date 的别名
232+
if SMALLDATETIME_OID in GAUSSDB_OID_ALIASES.get(DATE_OID, []):
233+
assert is_compatible_oid(DATE_OID, SMALLDATETIME_OID)
234+
235+
def test_different_oid_not_compatible(self):
236+
"""不同类型 OID 不兼容"""
237+
assert not is_compatible_oid(DATE_OID, 23) # int4
238+
239+
def test_get_oid_name(self):
240+
"""测试获取 OID 名称"""
241+
assert get_oid_name(DATE_OID) == "date"
242+
assert get_oid_name(TIMESTAMP_OID) == "timestamp"
243+
244+
def test_runtime_oid_fetch(self, conn):
245+
"""测试运行时 OID 查询"""
246+
oid = TypeInfo.fetch_runtime_oid(conn, "date")
247+
if oid is not None:
248+
assert isinstance(oid, int)
249+
assert oid > 0

0 commit comments

Comments
 (0)