Skip to content

Commit e1aa308

Browse files
author
liuhuiqi.7
committed
feat(crypto): encryption
1 parent 9de515e commit e1aa308

3 files changed

Lines changed: 51 additions & 37 deletions

File tree

volcenginesdkarkruntime/_client.py

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,19 @@ def __init__(
142142
def _get_endpoint_sts_token(self, endpoint_id: str, project_name: str = None):
143143
if self._sts_token_manager is None:
144144
if self.ak is None or self.sk is None:
145-
raise ArkAPIError("must set ak and sk before get endpoint token.")
146-
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
145+
raise ArkAPIError(
146+
"must set ak and sk before get endpoint token.")
147+
self._sts_token_manager = StsTokenManager(
148+
self.ak, self.sk, self.region)
147149
resource_type: str = self.get_resource_type_by_endpoint_id(endpoint_id)
148150
if resource_type == _PRESETENDPOINT_RESOURCE_TYPE and (project_name is None or project_name.strip() == ""):
149-
raise ArkAPIError("must set project_name when get preset endpoint token.")
151+
raise ArkAPIError(
152+
"must set project_name when get preset endpoint token.")
150153
return self._sts_token_manager.get(endpoint_id, resource_type=resource_type, project_name=project_name)
151154

152155
def _get_endpoint_certificate(
153156
self, endpoint_id: str
154-
) -> Tuple[key_agreement_client, str, str, float]:
157+
) -> key_agreement_client:
155158
if self._certificate_manager is None:
156159
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
157160
if (
@@ -171,8 +174,10 @@ def _get_endpoint_certificate(
171174
def _get_bot_sts_token(self, bot_id: str):
172175
if self._sts_token_manager is None:
173176
if self.ak is None or self.sk is None:
174-
raise ArkAPIError("must set ak and sk before get endpoint token.")
175-
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
177+
raise ArkAPIError(
178+
"must set ak and sk before get endpoint token.")
179+
self._sts_token_manager = StsTokenManager(
180+
self.ak, self.sk, self.region)
176181
return self._sts_token_manager.get(bot_id, resource_type="bot")
177182

178183
@property
@@ -288,20 +293,24 @@ def __init__(
288293
def _get_endpoint_sts_token(self, endpoint_id: str):
289294
if self._sts_token_manager is None:
290295
if self.ak is None or self.sk is None:
291-
raise ArkAPIError("must set ak and sk before get endpoint token.")
292-
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
296+
raise ArkAPIError(
297+
"must set ak and sk before get endpoint token.")
298+
self._sts_token_manager = StsTokenManager(
299+
self.ak, self.sk, self.region)
293300
return self._sts_token_manager.get(endpoint_id)
294301

295302
def _get_bot_sts_token(self, bot_id: str):
296303
if self._sts_token_manager is None:
297304
if self.ak is None or self.sk is None:
298-
raise ArkAPIError("must set ak and sk before get endpoint token.")
299-
self._sts_token_manager = StsTokenManager(self.ak, self.sk, self.region)
305+
raise ArkAPIError(
306+
"must set ak and sk before get endpoint token.")
307+
self._sts_token_manager = StsTokenManager(
308+
self.ak, self.sk, self.region)
300309
return self._sts_token_manager.get(bot_id, resource_type="bot")
301310

302311
def _get_endpoint_certificate(
303312
self, endpoint_id: str
304-
) -> Tuple[key_agreement_client, str, str, float]:
313+
) -> key_agreement_client:
305314
if self._certificate_manager is None:
306315
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
307316
if (
@@ -414,7 +423,8 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE, project
414423
)
415424

416425
def get(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE, project_name: str = None) -> str:
417-
self._refresh(ep, resource_type=resource_type, project_name=project_name)
426+
self._refresh(ep, resource_type=resource_type,
427+
project_name=project_name)
418428
return self._endpoint_sts_tokens[ep][0]
419429

420430
def _load_api_key(
@@ -455,9 +465,7 @@ def __init__(
455465
base_url: str | URL = BASE_URL,
456466
api_key: str | None = None,
457467
):
458-
self._certificate_manager: Dict[
459-
str, Tuple[key_agreement_client, str, str, float]
460-
] = {}
468+
self._certificate_manager: Dict[str, key_agreement_client] = {}
461469

462470
# local cache prepare
463471
self._init_local_cert_cache()
@@ -503,7 +511,8 @@ def _load_cert_by_ak_sk(self, ep: str) -> str:
503511
)
504512
if self._aicc_enabled:
505513
get_endpoint_certificate_request = (
506-
volcenginesdkark.GetEndpointCertificateRequest(id=ep, type="AICCv0.1")
514+
volcenginesdkark.GetEndpointCertificateRequest(
515+
id=ep, type="AICCv0.1")
507516
)
508517
try:
509518
resp: volcenginesdkark.GetEndpointCertificateResponse = (
@@ -519,7 +528,8 @@ def _load_cert_by_ak_sk(self, ep: str) -> str:
519528
return resp.pca_instance_certificate
520529

521530
def _sync_load_cert_by_auth(self, ep: str) -> str:
522-
try: # try to make request with session header (used for header statistic)
531+
# try to make request with session header (used for header statistic)
532+
try:
523533
req_body = {"model": ep}
524534
if self._aicc_enabled:
525535
req_body["type"] = "AICCv0.1"
@@ -574,22 +584,16 @@ def _init_local_cert_cache(self):
574584
% (self._cert_storage_path, e)
575585
)
576586

577-
def get(self, ep: str) -> Tuple[key_agreement_client, str, str, float]:
578-
if ep not in self._certificate_manager:
579-
cert_pem = self._load_cert_locally(ep)
580-
if cert_pem is None:
581-
if self.cert_path is not None:
582-
cert_pem = self._load_cert_by_cert_path()
583-
elif self._ark_client_enabled:
584-
cert_pem = self._sync_load_cert_by_auth(ep)
585-
else:
586-
cert_pem = self._load_cert_by_ak_sk(ep)
587-
self._save_cert_to_file(ep, cert_pem)
588-
ring, key, exp_time = get_cert_info(cert_pem)
589-
self._certificate_manager[ep] = (
590-
key_agreement_client(certificate_pem_string=cert_pem),
591-
ring,
592-
key,
593-
exp_time,
594-
)
595-
return self._certificate_manager[ep]
587+
def get(self, ep: str) -> key_agreement_client:
588+
if ep in self._certificate_manager and not self._certificate_manager[ep].is_expired():
589+
return self._certificate_manager[ep]
590+
cert_pem = self._load_cert_locally(ep)
591+
if cert_pem is None:
592+
if self.cert_path is not None:
593+
cert_pem = self._load_cert_by_cert_path()
594+
elif self._ark_client_enabled:
595+
cert_pem = self._sync_load_cert_by_auth(ep)
596+
else:
597+
cert_pem = self._load_cert_by_ak_sk(ep)
598+
self._save_cert_to_file(ep, cert_pem)
599+
self._certificate_manager[ep] = key_agreement_client(certificate_pem_string=cert_pem)

volcenginesdkarkruntime/_utils/_key_agreement.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,11 @@ def init_cert_ring_key_id(self) -> None:
235235
self._key_id = ""
236236
except Exception:
237237
pass
238+
239+
def get_cert_ring_key_id(self) -> Tuple[str, str]:
240+
"""get_cert_ring_key_id get ring id and key id from cert"""
241+
return self._ring_id, self._key_id
242+
243+
def get_cert_expiration_time(self) -> float:
244+
"""get_cert_expiration_time get cert expiration time"""
245+
return self._not_valid_after_utc

volcenginesdkarkruntime/resources/encryption.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,9 @@ def _content_encryption(args, kwargs):
154154
model: str = kwargs.get("model", "")
155155
messages = deepcopy(kwargs["messages"])
156156
ark_client = args[0]._client
157-
client, ring_id, key_id, exp_time = ark_client._get_endpoint_certificate(model)
157+
client = ark_client._get_endpoint_certificate(model)
158+
ring_id, key_id = client.get_cert_ring_key_id()
159+
exp_time = client.get_cert_expiration_time()
158160
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
159161
extra_headers["X-Session-Token"] = session_token
160162
_process_messages(

0 commit comments

Comments
 (0)