Skip to content

Commit 793feb9

Browse files
chore: add rds connect utils
1 parent 12b081c commit 793feb9

3 files changed

Lines changed: 73 additions & 52 deletions

File tree

volcenginesdkcore/feature/rds/connect_utils.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,72 +5,77 @@
55
This module provides utilities for generating authentication tokens for RDS MySQL database connections.
66
"""
77

8-
from volcenginesdkcore.signv4 import SignerV4
9-
from volcenginesdkcore.endpoint.providers.default_provider import DefaultEndpointProvider
8+
from volcenginesdkcore.endpoint.providers.standard_provider import StandardEndpointResolver
9+
from volcenginesdkcore.interceptor import InterceptorChain, InterceptorContext, SignRequestInterceptor, \
10+
ResolveEndpointInterceptor
11+
from volcenginesdkcore.interceptor import Request
1012

13+
DEFAULT_SERVICE = 'rds_mysql'
14+
DEFAULT_API_VERSION = '2022-01-01'
15+
DEFAULT_API = 'ConnectDatabase'
16+
DEFAULT_EXPIRES = 900
1117

12-
def build_auth_token(credentials, db_user, instance_id, region, expires=None):
18+
19+
def build_auth_token(api_client, db_user, instance_id, expires=None):
1320
"""
1421
Build an authentication token (presigned URL) for connecting to RDS MySQL database.
1522
16-
:param credentials: CredentialValue object with ak, sk, and optional session_token
23+
:param api_client: ApiClient instance
1724
:param db_user: Database username
1825
:param instance_id: RDS instance ID
19-
:param region: Service region (e.g., 'cn-beijing')
2026
:param expires: Token expiration time in seconds (default: 900, i.e., 15 minutes)
2127
:return: Presigned URL string for database authentication
2228
:raises ValueError: If required parameters are missing or invalid
2329
"""
24-
# Validate inputs
25-
if credentials is None:
26-
raise ValueError("credentials must not be None")
30+
# Validate api_client
31+
if api_client is None:
32+
raise ValueError("api_client must not be None")
2733

28-
if not hasattr(credentials, 'ak') or not credentials.ak:
29-
raise ValueError("credentials.ak must not be empty")
34+
configuration = api_client.configuration
35+
region = configuration.region
3036

31-
if not hasattr(credentials, 'sk') or not credentials.sk:
32-
raise ValueError("credentials.sk must not be empty")
37+
# Validate inputs
38+
if not region:
39+
raise ValueError("region must not be empty")
3340

3441
if not db_user:
3542
raise ValueError("db_user must not be empty")
3643

3744
if not instance_id:
3845
raise ValueError("instance_id must not be empty")
3946

40-
if not region:
41-
raise ValueError("region must not be empty")
42-
4347
# Set default expiration time
4448
if expires is None or expires <= 0:
45-
expires = 900 # 15 minutes
46-
47-
# Service configuration
48-
service = 'rds_mysql'
49-
50-
# Get endpoint
51-
endpoint_provider = DefaultEndpointProvider()
52-
resolved_endpoint = endpoint_provider.endpoint_for(service, region)
53-
host = resolved_endpoint.host
49+
expires = DEFAULT_EXPIRES
5450

5551
# Build query parameters
5652
query = {
57-
'Action': 'ConnectDatabase',
58-
'Version': '2022-01-01',
53+
'Action': DEFAULT_API,
54+
'Version': DEFAULT_API_VERSION,
5955
'X-Expires': str(expires),
6056
'DBUser': db_user,
6157
'InstanceId': instance_id,
6258
}
6359

64-
# Sign the URL
65-
signed_query = SignerV4.sign_url(
66-
path='/',
67-
method='GET',
68-
query=query,
69-
ak=credentials.ak,
70-
sk=credentials.sk,
71-
region=region,
72-
service=service,
73-
session_token=getattr(credentials, 'session_token', None)
74-
)
75-
76-
return signed_query
60+
# Create Request with presign mode
61+
request = Request(configuration,
62+
resource_path='/{}/{}/{}/get/text_plain/'.format(DEFAULT_API, DEFAULT_API_VERSION,
63+
DEFAULT_SERVICE),
64+
method='GET',
65+
query_params=query)
66+
request.host = None # Force endpoint resolution by interceptor
67+
request.endpoint_provider = StandardEndpointResolver()
68+
request.service = DEFAULT_SERVICE
69+
request.is_presign = True
70+
71+
# Create interceptor chain:
72+
# ResolveEndpointInterceptor - resolves endpoint + scheme
73+
# SignRequestInterceptor - presign URL signing
74+
chain = InterceptorChain()
75+
chain.append_request_interceptor(ResolveEndpointInterceptor())
76+
chain.append_request_interceptor(SignRequestInterceptor())
77+
78+
context = InterceptorContext(request=request)
79+
context = chain.execute_request(context)
80+
81+
return '{url}?{query}'.format(url=context.request.url, query=context.request.signed_query)

volcenginesdkcore/interceptor/interceptors/request.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def __init__(
6363
self.retryer = configuration.retryer
6464
self.credential_provider = configuration.credential_provider
6565

66+
# Presign support
67+
self.is_presign = False
68+
self.signed_query = None
69+
6670
self.runtime_options = None
6771
if hasattr(body, '_configuration') and isinstance(body._configuration, RuntimeOption):
6872
self.runtime_options = body._configuration

volcenginesdkcore/interceptor/interceptors/sign_request_interceptor.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,35 @@ def name(self):
1313
def intercept(self, context):
1414
# 新增代码。处理assume_role和assume_role_oidc和assume_role_saml
1515
if context.request.credential_provider is not None:
16-
credentials = context.request.credential_provider.get_credentials() # 这会调用 _assume_role_oidc() 方法获取临时凭证
16+
credentials = context.request.credential_provider.get_credentials() # 这会调用 _assume_role_oidc() 方法获取临时凭证
1717
context.request.ak = credentials.ak
1818
context.request.sk = credentials.sk
1919
context.request.session_token = credentials.session_token
2020

21-
self.update_params_for_auth(host=context.request.host, path=context.request.true_path,
22-
method=context.request.method,
23-
headers=context.request.header_params,
24-
querys=context.request.query_params,
25-
auth_settings=context.request.auth_settings,
26-
body=context.request.body,
27-
post_params=context.request.post_params,
28-
service=context.request.service,
29-
ak=context.request.ak,
30-
sk=context.request.sk,
31-
session_token=context.request.session_token,
32-
region=context.request.region)
21+
if context.request.is_presign:
22+
context.request.signed_query = SignerV4.sign_url(
23+
path=context.request.true_path,
24+
method=context.request.method,
25+
query=context.request.query_params,
26+
ak=context.request.ak,
27+
sk=context.request.sk,
28+
region=context.request.region,
29+
service=context.request.service,
30+
session_token=context.request.session_token,
31+
)
32+
else:
33+
self.update_params_for_auth(host=context.request.host, path=context.request.true_path,
34+
method=context.request.method,
35+
headers=context.request.header_params,
36+
querys=context.request.query_params,
37+
auth_settings=context.request.auth_settings,
38+
body=context.request.body,
39+
post_params=context.request.post_params,
40+
service=context.request.service,
41+
ak=context.request.ak,
42+
sk=context.request.sk,
43+
session_token=context.request.session_token,
44+
region=context.request.region)
3345
return context
3446

3547
@staticmethod

0 commit comments

Comments
 (0)