2424 _DEFAULT_MANDATORY_REFRESH_TIMEOUT ,
2525 _DEFAULT_STS_TIMEOUT ,
2626 _DEFAULT_RESOURCE_TYPE ,
27- DEFAULT_TIMEOUT
27+ DEFAULT_TIMEOUT ,
2828)
2929from ._streaming import Stream
3030
@@ -84,7 +84,9 @@ def __init__(
8484 self .api_key = api_key
8585 self .region = region
8686
87- assert (api_key is not None ) or (ak is not None and sk is not None ), "you need to support api_key or ak&sk"
87+ assert (api_key is not None ) or (ak is not None and sk is not None ), (
88+ "you need to support api_key or ak&sk"
89+ )
8890
8991 super ().__init__ (
9092 base_url = base_url ,
@@ -120,10 +122,18 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
120122 def _get_endpoint_certificate (self , endpoint_id : str ) -> key_agreement_client :
121123 if self ._certificate_manager is None :
122124 cert_path = os .environ .get ("E2E_CERTIFICATE_PATH" )
123- if (self .ak is None or self .sk is None ) and cert_path is None and self .api_key is None :
124- raise ArkAPIError ("must set (api_key) or (ak and sk) \
125- or (E2E_CERTIFICATE_PATH) before get endpoint token." )
126- self ._certificate_manager = E2ECertificateManager (self .ak , self .sk , self .region , self ._base_url , self .api_key )
125+ if (
126+ (self .ak is None or self .sk is None )
127+ and cert_path is None
128+ and self .api_key is None
129+ ):
130+ raise ArkAPIError (
131+ "must set (api_key) or (ak and sk) \
132+ or (E2E_CERTIFICATE_PATH) before get endpoint token."
133+ )
134+ self ._certificate_manager = E2ECertificateManager (
135+ self .ak , self .sk , self .region , self ._base_url , self .api_key
136+ )
127137 return self ._certificate_manager .get (endpoint_id )
128138
129139 def _get_bot_sts_token (self , bot_id : str ):
@@ -142,6 +152,7 @@ def get_model_breaker(self, model_name: str) -> ModelBreaker:
142152 with self .model_breaker_lock :
143153 return self .model_breaker_map [model_name ]
144154
155+
145156class AsyncArk (AsyncAPIClient ):
146157 chat : resources .AsyncChat
147158 bot_chat : resources .AsyncBotChat
@@ -168,15 +179,15 @@ def __init__(
168179 ) -> None :
169180 """init async ark client, this client is thread unsafe
170181
171- Args:
172- ak: access key id
173- sk: secret access key
174- api_key: api key,this api key will not be refreshed
175- timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
176- max_retries: times of retry when request failed. default 1
177- http_client: specify customized http_client
178- Returns:
179- async ark client
182+ Args:
183+ ak: access key id
184+ sk: secret access key
185+ api_key: api key,this api key will not be refreshed
186+ timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
187+ max_retries: times of retry when request failed. default 1
188+ http_client: specify customized http_client
189+ Returns:
190+ async ark client
180191 """
181192
182193 if ak is None :
@@ -191,7 +202,9 @@ def __init__(
191202 self .api_key = api_key
192203 self .region = region
193204
194- assert (api_key is not None ) or (ak is not None and sk is not None ), "you need to support api_key or ak&sk"
205+ assert (api_key is not None ) or (ak is not None and sk is not None ), (
206+ "you need to support api_key or ak&sk"
207+ )
195208
196209 super ().__init__ (
197210 base_url = base_url ,
@@ -227,10 +240,18 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
227240 def _get_endpoint_certificate (self , endpoint_id : str ) -> key_agreement_client :
228241 if self ._certificate_manager is None :
229242 cert_path = os .environ .get ("E2E_CERTIFICATE_PATH" )
230- if (self .ak is None or self .sk is None ) and cert_path is None and self .api_key is None :
231- raise ArkAPIError ("must set (api_key) or (ak and sk) \
232- or (E2E_CERTIFICATE_PATH) before get endpoint token." )
233- self ._certificate_manager = E2ECertificateManager (self .ak , self .sk , self .region , self ._base_url , self .api_key )
243+ if (
244+ (self .ak is None or self .sk is None )
245+ and cert_path is None
246+ and self .api_key is None
247+ ):
248+ raise ArkAPIError (
249+ "must set (api_key) or (ak and sk) \
250+ or (E2E_CERTIFICATE_PATH) before get endpoint token."
251+ )
252+ self ._certificate_manager = E2ECertificateManager (
253+ self .ak , self .sk , self .region , self ._base_url , self .api_key
254+ )
234255 return self ._certificate_manager .get (endpoint_id )
235256
236257 @property
@@ -252,7 +273,9 @@ class StsTokenManager(object):
252273 _mandatory_refresh_timeout : int = _DEFAULT_MANDATORY_REFRESH_TIMEOUT
253274
254275 def __init__ (self , ak : str , sk : str , region : str ):
255- self ._endpoint_sts_tokens : Dict [str , Tuple [str , int ]] = defaultdict (lambda : ("" , 0 ))
276+ self ._endpoint_sts_tokens : Dict [str , Tuple [str , int ]] = defaultdict (
277+ lambda : ("" , 0 )
278+ )
256279 self ._refresh_lock = threading .Lock ()
257280
258281 import volcenginesdkcore
@@ -272,10 +295,19 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
272295
273296 return self ._endpoint_sts_tokens [ep ][1 ] - time .time () < refresh_in
274297
275- def _protected_refresh (self , ep : str , ttl : int = _DEFAULT_STS_TIMEOUT , is_mandatory : bool = False ,
276- resource_type : str = _DEFAULT_RESOURCE_TYPE ):
298+ def _protected_refresh (
299+ self ,
300+ ep : str ,
301+ ttl : int = _DEFAULT_STS_TIMEOUT ,
302+ is_mandatory : bool = False ,
303+ resource_type : str = _DEFAULT_RESOURCE_TYPE ,
304+ ):
277305 if ttl < self ._advisory_refresh_timeout * 2 :
278- raise ArkAPIError ("ttl should not be under {} seconds." .format (self ._advisory_refresh_timeout * 2 ))
306+ raise ArkAPIError (
307+ "ttl should not be under {} seconds." .format (
308+ self ._advisory_refresh_timeout * 2
309+ )
310+ )
279311
280312 try :
281313 api_key , expired_time = self ._load_api_key (
@@ -301,7 +333,9 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
301333 ep , self ._mandatory_refresh_timeout
302334 )
303335
304- self ._protected_refresh (ep , is_mandatory = is_mandatory_refresh , resource_type = resource_type )
336+ self ._protected_refresh (
337+ ep , is_mandatory = is_mandatory_refresh , resource_type = resource_type
338+ )
305339 return
306340 finally :
307341 self ._refresh_lock .release ()
@@ -310,14 +344,20 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
310344 if not self ._need_refresh (ep , self ._mandatory_refresh_timeout ):
311345 return
312346
313- self ._protected_refresh (ep , is_mandatory = True , resource_type = resource_type )
347+ self ._protected_refresh (
348+ ep , is_mandatory = True , resource_type = resource_type
349+ )
314350
315351 def get (self , ep : str , resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> str :
316352 self ._refresh (ep , resource_type = resource_type )
317353 return self ._endpoint_sts_tokens [ep ][0 ]
318354
319- def _load_api_key (self , ep : str , duration_seconds : int ,
320- resource_type : str = _DEFAULT_RESOURCE_TYPE ) -> Tuple [str , int ]:
355+ def _load_api_key (
356+ self ,
357+ ep : str ,
358+ duration_seconds : int ,
359+ resource_type : str = _DEFAULT_RESOURCE_TYPE ,
360+ ) -> Tuple [str , int ]:
321361 get_api_key_request = volcenginesdkark .GetApiKeyRequest (
322362 duration_seconds = duration_seconds ,
323363 resource_type = resource_type ,
@@ -331,19 +371,26 @@ def _load_api_key(self, ep: str, duration_seconds: int,
331371
332372
333373class E2ECertificateManager (object ):
334-
335- class CertificateResponse ():
374+ class CertificateResponse :
336375 Certificate : str
337376 """The certificate content."""
338377
339- def __init__ (self , ak : str , sk : str , region : str , base_url : str | URL = BASE_URL , api_key : str | None = None ):
378+ def __init__ (
379+ self ,
380+ ak : str ,
381+ sk : str ,
382+ region : str ,
383+ base_url : str | URL = BASE_URL ,
384+ api_key : str | None = None ,
385+ ):
340386 self ._certificate_manager : Dict [str , key_agreement_client ] = {}
341387
342388 # local cache prepare
343389 self ._init_local_cert_cache ()
344390
345391 # api instance prepare
346392 import volcenginesdkcore
393+
347394 configuration = volcenginesdkcore .Configuration ()
348395 configuration .ak = ak
349396 configuration .sk = sk
@@ -365,38 +412,47 @@ def __init__(self, ak: str, sk: str, region: str, base_url: str | URL = BASE_URL
365412 api_key = api_key ,
366413 )
367414 self ._e2e_uri = "/e2e/get/certificate"
368- self ._x_session_token = {' X-Session-Token' : self ._e2e_uri }
415+ self ._x_session_token = {" X-Session-Token" : self ._e2e_uri }
369416
370417 def _load_cert_by_cert_path (self ) -> str :
371- with open (self .cert_path , 'r' ) as f :
418+ with open (self .cert_path , "r" ) as f :
372419 cert_pem = f .read ()
373420 return cert_pem
374421
375422 def _load_cert_by_ak_sk (self , ep : str ) -> str :
376- get_endpoint_certificate_request = volcenginesdkark . GetEndpointCertificateRequest (
377- id = ep
423+ get_endpoint_certificate_request = (
424+ volcenginesdkark . GetEndpointCertificateRequest ( id = ep )
378425 )
379426 try :
380- resp : volcenginesdkark .GetEndpointCertificateResponse = self .api_instance .get_endpoint_certificate (
381- get_endpoint_certificate_request )
427+ resp : volcenginesdkark .GetEndpointCertificateResponse = (
428+ self .api_instance .get_endpoint_certificate (
429+ get_endpoint_certificate_request
430+ )
431+ )
382432 except ApiException as e :
383- raise ArkAPIError ("Getting model vendor encryption certificate failed: %s\n " % e )
433+ raise ArkAPIError (
434+ "Getting model vendor encryption certificate failed: %s\n " % e
435+ )
384436
385437 return resp .pca_instance_certificate
386438
387439 def _sync_load_cert_by_auth (self , ep : str ) -> str :
388440 try : # try to make request with session header (used for header statistic)
389- resp = self .client .post (self ._e2e_uri , options = {"headers" : self ._x_session_token },
390- body = {"model" : ep }, cast_to = self .CertificateResponse )
441+ resp = self .client .post (
442+ self ._e2e_uri ,
443+ options = {"headers" : self ._x_session_token },
444+ body = {"model" : ep },
445+ cast_to = self .CertificateResponse ,
446+ )
391447 except Exception as e :
392448 raise ArkAPIError ("Getting Certificate failed: %s\n " % e )
393- if ' error' in resp :
394- raise ArkAPIError ("Getting Certificate failed: %s\n " % resp [' error' ])
395- return resp [' Certificate' ]
449+ if " error" in resp :
450+ raise ArkAPIError ("Getting Certificate failed: %s\n " % resp [" error" ])
451+ return resp [" Certificate" ]
396452
397453 def _save_cert_to_file (self , ep : str , cert_pem : str ):
398454 cert_file_path = os .path .join (self ._cert_storage_path , f"{ ep } .pem" )
399- with open (cert_file_path , 'w' ) as f :
455+ with open (cert_file_path , "w" ) as f :
400456 f .write (cert_pem )
401457
402458 def _load_cert_locally (self , ep : str ) -> str | None :
@@ -406,7 +462,7 @@ def _load_cert_locally(self, ep: str) -> str | None:
406462 current_time = time .time ()
407463 time_difference = current_time - last_modified_time
408464 if time_difference <= self ._cert_expiration_seconds :
409- with open (cert_file_path , 'r' ) as f :
465+ with open (cert_file_path , "r" ) as f :
410466 return f .read ()
411467 else :
412468 os .remove (cert_file_path )
0 commit comments