Skip to content

Commit c1e01e9

Browse files
committed
Merge 'feature/support_rds_auth' into 'master'
chore: add rds connect utils See merge request: !1055
2 parents 2ed6b4e + 1005f3c commit c1e01e9

7 files changed

Lines changed: 214 additions & 14 deletions

File tree

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# coding: utf-8
2+
"""
3+
Feature modules for specific service utilities.
4+
5+
This package contains specialized utility modules for various Volcengine services.
6+
"""
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# coding: utf-8
2+
from volcenginesdkcore.feature.rds.connect_utils import build_auth_token
3+
4+
__all__ = ['build_auth_token']
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# coding: utf-8
2+
"""
3+
RDS MySQL connection authentication utilities.
4+
5+
This module provides utilities for generating authentication tokens for RDS MySQL database connections.
6+
"""
7+
8+
from volcenginesdkcore.endpoint.providers.standard_provider import StandardEndpointResolver
9+
from volcenginesdkcore.interceptor import InterceptorChain, InterceptorContext, SignRequestInterceptor, \
10+
ResolveEndpointInterceptor
11+
from volcenginesdkcore.interceptor import Request
12+
13+
DEFAULT_SERVICE = 'rds_mysql'
14+
DEFAULT_API_VERSION = '2022-01-01'
15+
DEFAULT_API = 'ConnectDatabase'
16+
DEFAULT_EXPIRES = 900
17+
18+
19+
def build_auth_token(api_client, db_user, instance_id, expires=None):
20+
"""
21+
Build an authentication token (presigned URL) for connecting to RDS MySQL database.
22+
23+
:param api_client: ApiClient instance
24+
:param db_user: Database username
25+
:param instance_id: RDS instance ID
26+
:param expires: Token expiration time in seconds (default: 900, i.e., 15 minutes)
27+
:return: Presigned URL string for database authentication
28+
:raises ValueError: If required parameters are missing or invalid
29+
"""
30+
# Validate api_client
31+
if api_client is None:
32+
raise ValueError("api_client must not be None")
33+
34+
configuration = api_client.configuration
35+
region = configuration.region
36+
37+
# Validate inputs
38+
if not region:
39+
raise ValueError("region must not be empty")
40+
41+
if not db_user:
42+
raise ValueError("db_user must not be empty")
43+
44+
if not instance_id:
45+
raise ValueError("instance_id must not be empty")
46+
47+
# Set default expiration time
48+
if expires is None or expires <= 0:
49+
expires = DEFAULT_EXPIRES
50+
51+
# Build query parameters
52+
query = {
53+
'Action': DEFAULT_API,
54+
'Version': DEFAULT_API_VERSION,
55+
'X-Expires': str(expires),
56+
'DBUser': db_user,
57+
'InstanceId': instance_id,
58+
}
59+
60+
# Create Request with presign mode
61+
request = Request(configuration,
62+
resource_path='/{}/{}/{}/get/text_plain/'.format(DEFAULT_API, DEFAULT_API_VERSION,
63+
DEFAULT_SERVICE),
64+
method='GET',
65+
query_params=query)
66+
request.endpoint_provider = StandardEndpointResolver()
67+
request.service = DEFAULT_SERVICE
68+
request.is_presign = True
69+
70+
# Create interceptor chain:
71+
# ResolveEndpointInterceptor - resolves endpoint + scheme
72+
# SignRequestInterceptor - presign URL signing
73+
chain = InterceptorChain()
74+
chain.append_request_interceptor(ResolveEndpointInterceptor())
75+
chain.append_request_interceptor(SignRequestInterceptor())
76+
77+
context = InterceptorContext(request=request)
78+
context = chain.execute_request(context)
79+
80+
return '{url}?{query}'.format(url=context.request.url, query=context.request.signed_query)

volcenginesdkcore/interceptor/interceptors/request.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ def __init__(
6363
self.retryer = configuration.retryer
6464
self.credential_provider = configuration.credential_provider
6565

66+
# Presign support
67+
self.is_presign = False
68+
self.signed_query = None
69+
6670
self.runtime_options = None
6771
if hasattr(body, '_configuration') and isinstance(body._configuration, RuntimeOption):
6872
self.runtime_options = body._configuration

volcenginesdkcore/interceptor/interceptors/resolve_endpoint_interceptor.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@ def intercept(self, context):
1919
context.request.host = endpoint_resolver.host
2020
prefix = endpoint_resolver.url_for(scheme)
2121
else:
22-
prefix = scheme + '://' + host
22+
if host.startswith('https://'):
23+
prefix = host
24+
context.request.host = host[len('https://'):]
25+
elif host.startswith('http://'):
26+
prefix = host
27+
context.request.host = host[len('http://'):]
28+
else:
29+
prefix = scheme + '://' + host
2330
context.request.url = prefix + context.request.true_path
2431
sdk_core_logger.debug_endpoint(
2532
"Using endpoint: %s", context.request.host

volcenginesdkcore/interceptor/interceptors/sign_request_interceptor.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,36 @@ def name(self):
1313
def intercept(self, context):
1414
# 新增代码。处理assume_role和assume_role_oidc和assume_role_saml
1515
if context.request.credential_provider is not None:
16-
credentials = context.request.credential_provider.get_credentials() # 这会调用 _assume_role_oidc() 方法获取临时凭证
16+
credentials = context.request.credential_provider.get_credentials() # 这会调用 _assume_role_oidc() 方法获取临时凭证
1717
context.request.ak = credentials.ak
1818
context.request.sk = credentials.sk
1919
context.request.session_token = credentials.session_token
2020

21-
self.update_params_for_auth(host=context.request.host, path=context.request.true_path,
22-
method=context.request.method,
23-
headers=context.request.header_params,
24-
querys=context.request.query_params,
25-
auth_settings=context.request.auth_settings,
26-
body=context.request.body,
27-
post_params=context.request.post_params,
28-
service=context.request.service,
29-
ak=context.request.ak,
30-
sk=context.request.sk,
31-
session_token=context.request.session_token,
32-
region=context.request.region)
21+
if context.request.is_presign:
22+
context.request.signed_query = SignerV4.sign_url(
23+
path=context.request.true_path,
24+
method=context.request.method,
25+
query=context.request.query_params,
26+
ak=context.request.ak,
27+
sk=context.request.sk,
28+
region=context.request.region,
29+
service=context.request.service,
30+
session_token=context.request.session_token,
31+
host=context.request.host,
32+
)
33+
else:
34+
self.update_params_for_auth(host=context.request.host, path=context.request.true_path,
35+
method=context.request.method,
36+
headers=context.request.header_params,
37+
querys=context.request.query_params,
38+
auth_settings=context.request.auth_settings,
39+
body=context.request.body,
40+
post_params=context.request.post_params,
41+
service=context.request.service,
42+
ak=context.request.ak,
43+
sk=context.request.sk,
44+
session_token=context.request.session_token,
45+
region=context.request.region)
3346
return context
3447

3548
@staticmethod

volcenginesdkcore/signv4.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,89 @@ def get_signing_secret_key_v4(sk, date, region, service):
9191
@staticmethod
9292
def hmac_sha256(key, msg):
9393
return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
94+
95+
@staticmethod
96+
def sign_url(path, method, query, ak, sk, region, service, session_token=None, host=None):
97+
"""
98+
Generate presigned URL query string (AWS Signature V4)
99+
100+
:param path: Request path
101+
:param method: HTTP method (GET, POST, etc.)
102+
:param query: Query parameters dict
103+
:param ak: Access Key
104+
:param sk: Secret Key
105+
:param region: Service region
106+
:param service: Service name
107+
:param session_token: Optional session token
108+
:param host: Optional host header to sign
109+
:return: Query string with signature
110+
"""
111+
format_date = datetime.datetime.utcnow().strftime("%Y%m%dT%H%M%SZ")
112+
date = format_date[:8]
113+
114+
# Build credential scope
115+
credential_scope = '/'.join([date, region, service, 'request'])
116+
117+
# Determine if host header should be signed
118+
sign_host = host is not None and host != ''
119+
120+
# Add required query parameters
121+
query = dict(query) # Make a copy to avoid modifying original
122+
query['X-Date'] = format_date
123+
query['X-NotSignBody'] = ''
124+
query['X-Credential'] = ak + '/' + credential_scope
125+
query['X-Algorithm'] = 'HMAC-SHA256'
126+
query['X-SignedHeaders'] = 'host' if sign_host else ''
127+
query['X-SignedQueries'] = ''
128+
129+
# Generate X-SignedQueries BEFORE adding X-Security-Token
130+
query['X-SignedQueries'] = ';'.join(sorted(query.keys()))
131+
signed_query_keys = set(query.keys())
132+
133+
# X-Security-Token must be added AFTER X-SignedQueries calculation
134+
if session_token:
135+
query['X-Security-Token'] = session_token
136+
137+
# Build canonical request
138+
body_hash = hashlib.sha256(b'').hexdigest()
139+
canonical_query_params = {k: v for k, v in query.items() if k in signed_query_keys}
140+
141+
if sign_host:
142+
canonical_request = '\n'.join([
143+
method,
144+
path,
145+
SignerV4.canonical_query(canonical_query_params),
146+
'host:' + host + '\n',
147+
'host',
148+
body_hash
149+
])
150+
else:
151+
canonical_request = '\n'.join([
152+
method,
153+
path,
154+
SignerV4.canonical_query(canonical_query_params),
155+
'\n',
156+
'',
157+
body_hash
158+
])
159+
sdk_core_logger.debug_sign("[sign_url] canonical_request:\n%s", canonical_request)
160+
161+
# Build string to sign
162+
signing_str = '\n'.join([
163+
'HMAC-SHA256',
164+
format_date,
165+
credential_scope,
166+
hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()
167+
])
168+
sdk_core_logger.debug_sign("[sign_url] string_to_sign:\n%s", signing_str)
169+
170+
# Calculate signature
171+
signing_key = SignerV4.get_signing_secret_key_v4(sk, date, region, service)
172+
signature = hmac.new(signing_key, signing_str.encode('utf-8'), hashlib.sha256).hexdigest()
173+
sdk_core_logger.debug_sign("[sign_url] calculated signature: %s", signature)
174+
175+
# Add signature to query
176+
query['X-Signature'] = signature
177+
178+
# Return encoded query string
179+
return urlencode(sorted(query.items()))

0 commit comments

Comments
 (0)