Skip to content

Commit b27f80a

Browse files
author
潘婉宁
committed
feat: support thinking
1 parent 655fa04 commit b27f80a

26 files changed

Lines changed: 497 additions & 385 deletions

volcenginesdkarkruntime/_base_client.py

Lines changed: 195 additions & 187 deletions
Large diffs are not rendered by default.

volcenginesdkarkruntime/_client.py

Lines changed: 101 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
_DEFAULT_MANDATORY_REFRESH_TIMEOUT,
2525
_DEFAULT_STS_TIMEOUT,
2626
_DEFAULT_RESOURCE_TYPE,
27-
DEFAULT_TIMEOUT
27+
DEFAULT_TIMEOUT,
2828
)
2929
from ._streaming import Stream
3030

@@ -84,7 +84,9 @@ def __init__(
8484
self.api_key = api_key
8585
self.region = region
8686

87-
assert (api_key is not None) or (ak is not None and sk is not None), "you need to support api_key or ak&sk"
87+
assert (api_key is not None) or (ak is not None and sk is not None), (
88+
"you need to support api_key or ak&sk"
89+
)
8890

8991
super().__init__(
9092
base_url=base_url,
@@ -120,10 +122,18 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
120122
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
121123
if self._certificate_manager is None:
122124
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
123-
if (self.ak is None or self.sk is None) and cert_path is None and self.api_key is None:
124-
raise ArkAPIError("must set (api_key) or (ak and sk) \
125-
or (E2E_CERTIFICATE_PATH) before get endpoint token.")
126-
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region, self._base_url, self.api_key)
125+
if (
126+
(self.ak is None or self.sk is None)
127+
and cert_path is None
128+
and self.api_key is None
129+
):
130+
raise ArkAPIError(
131+
"must set (api_key) or (ak and sk) \
132+
or (E2E_CERTIFICATE_PATH) before get endpoint token."
133+
)
134+
self._certificate_manager = E2ECertificateManager(
135+
self.ak, self.sk, self.region, self._base_url, self.api_key
136+
)
127137
return self._certificate_manager.get(endpoint_id)
128138

129139
def _get_bot_sts_token(self, bot_id: str):
@@ -142,6 +152,7 @@ def get_model_breaker(self, model_name: str) -> ModelBreaker:
142152
with self.model_breaker_lock:
143153
return self.model_breaker_map[model_name]
144154

155+
145156
class AsyncArk(AsyncAPIClient):
146157
chat: resources.AsyncChat
147158
bot_chat: resources.AsyncBotChat
@@ -168,15 +179,15 @@ def __init__(
168179
) -> None:
169180
"""init async ark client, this client is thread unsafe
170181
171-
Args:
172-
ak: access key id
173-
sk: secret access key
174-
api_key: api key,this api key will not be refreshed
175-
timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
176-
max_retries: times of retry when request failed. default 1
177-
http_client: specify customized http_client
178-
Returns:
179-
async ark client
182+
Args:
183+
ak: access key id
184+
sk: secret access key
185+
api_key: api key,this api key will not be refreshed
186+
timeout: timeout of client. default httpx.Timeout(timeout=60.0, connect=60.0)
187+
max_retries: times of retry when request failed. default 1
188+
http_client: specify customized http_client
189+
Returns:
190+
async ark client
180191
"""
181192

182193
if ak is None:
@@ -191,7 +202,9 @@ def __init__(
191202
self.api_key = api_key
192203
self.region = region
193204

194-
assert (api_key is not None) or (ak is not None and sk is not None), "you need to support api_key or ak&sk"
205+
assert (api_key is not None) or (ak is not None and sk is not None), (
206+
"you need to support api_key or ak&sk"
207+
)
195208

196209
super().__init__(
197210
base_url=base_url,
@@ -227,10 +240,18 @@ def _get_endpoint_sts_token(self, endpoint_id: str):
227240
def _get_endpoint_certificate(self, endpoint_id: str) -> key_agreement_client:
228241
if self._certificate_manager is None:
229242
cert_path = os.environ.get("E2E_CERTIFICATE_PATH")
230-
if (self.ak is None or self.sk is None) and cert_path is None and self.api_key is None:
231-
raise ArkAPIError("must set (api_key) or (ak and sk) \
232-
or (E2E_CERTIFICATE_PATH) before get endpoint token.")
233-
self._certificate_manager = E2ECertificateManager(self.ak, self.sk, self.region, self._base_url, self.api_key)
243+
if (
244+
(self.ak is None or self.sk is None)
245+
and cert_path is None
246+
and self.api_key is None
247+
):
248+
raise ArkAPIError(
249+
"must set (api_key) or (ak and sk) \
250+
or (E2E_CERTIFICATE_PATH) before get endpoint token."
251+
)
252+
self._certificate_manager = E2ECertificateManager(
253+
self.ak, self.sk, self.region, self._base_url, self.api_key
254+
)
234255
return self._certificate_manager.get(endpoint_id)
235256

236257
@property
@@ -252,7 +273,9 @@ class StsTokenManager(object):
252273
_mandatory_refresh_timeout: int = _DEFAULT_MANDATORY_REFRESH_TIMEOUT
253274

254275
def __init__(self, ak: str, sk: str, region: str):
255-
self._endpoint_sts_tokens: Dict[str, Tuple[str, int]] = defaultdict(lambda: ("", 0))
276+
self._endpoint_sts_tokens: Dict[str, Tuple[str, int]] = defaultdict(
277+
lambda: ("", 0)
278+
)
256279
self._refresh_lock = threading.Lock()
257280

258281
import volcenginesdkcore
@@ -272,10 +295,19 @@ def _need_refresh(self, ep: str, refresh_in: int | None = None) -> bool:
272295

273296
return self._endpoint_sts_tokens[ep][1] - time.time() < refresh_in
274297

275-
def _protected_refresh(self, ep: str, ttl: int = _DEFAULT_STS_TIMEOUT, is_mandatory: bool = False,
276-
resource_type: str = _DEFAULT_RESOURCE_TYPE):
298+
def _protected_refresh(
299+
self,
300+
ep: str,
301+
ttl: int = _DEFAULT_STS_TIMEOUT,
302+
is_mandatory: bool = False,
303+
resource_type: str = _DEFAULT_RESOURCE_TYPE,
304+
):
277305
if ttl < self._advisory_refresh_timeout * 2:
278-
raise ArkAPIError("ttl should not be under {} seconds.".format(self._advisory_refresh_timeout * 2))
306+
raise ArkAPIError(
307+
"ttl should not be under {} seconds.".format(
308+
self._advisory_refresh_timeout * 2
309+
)
310+
)
279311

280312
try:
281313
api_key, expired_time = self._load_api_key(
@@ -301,7 +333,9 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
301333
ep, self._mandatory_refresh_timeout
302334
)
303335

304-
self._protected_refresh(ep, is_mandatory=is_mandatory_refresh, resource_type=resource_type)
336+
self._protected_refresh(
337+
ep, is_mandatory=is_mandatory_refresh, resource_type=resource_type
338+
)
305339
return
306340
finally:
307341
self._refresh_lock.release()
@@ -310,14 +344,20 @@ def _refresh(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE):
310344
if not self._need_refresh(ep, self._mandatory_refresh_timeout):
311345
return
312346

313-
self._protected_refresh(ep, is_mandatory=True, resource_type=resource_type)
347+
self._protected_refresh(
348+
ep, is_mandatory=True, resource_type=resource_type
349+
)
314350

315351
def get(self, ep: str, resource_type: str = _DEFAULT_RESOURCE_TYPE) -> str:
316352
self._refresh(ep, resource_type=resource_type)
317353
return self._endpoint_sts_tokens[ep][0]
318354

319-
def _load_api_key(self, ep: str, duration_seconds: int,
320-
resource_type: str = _DEFAULT_RESOURCE_TYPE) -> Tuple[str, int]:
355+
def _load_api_key(
356+
self,
357+
ep: str,
358+
duration_seconds: int,
359+
resource_type: str = _DEFAULT_RESOURCE_TYPE,
360+
) -> Tuple[str, int]:
321361
get_api_key_request = volcenginesdkark.GetApiKeyRequest(
322362
duration_seconds=duration_seconds,
323363
resource_type=resource_type,
@@ -331,19 +371,26 @@ def _load_api_key(self, ep: str, duration_seconds: int,
331371

332372

333373
class E2ECertificateManager(object):
334-
335-
class CertificateResponse():
374+
class CertificateResponse:
336375
Certificate: str
337376
"""The certificate content."""
338377

339-
def __init__(self, ak: str, sk: str, region: str, base_url: str | URL = BASE_URL, api_key: str | None = None):
378+
def __init__(
379+
self,
380+
ak: str,
381+
sk: str,
382+
region: str,
383+
base_url: str | URL = BASE_URL,
384+
api_key: str | None = None,
385+
):
340386
self._certificate_manager: Dict[str, key_agreement_client] = {}
341387

342388
# local cache prepare
343389
self._init_local_cert_cache()
344390

345391
# api instance prepare
346392
import volcenginesdkcore
393+
347394
configuration = volcenginesdkcore.Configuration()
348395
configuration.ak = ak
349396
configuration.sk = sk
@@ -365,38 +412,47 @@ def __init__(self, ak: str, sk: str, region: str, base_url: str | URL = BASE_URL
365412
api_key=api_key,
366413
)
367414
self._e2e_uri = "/e2e/get/certificate"
368-
self._x_session_token = {'X-Session-Token': self._e2e_uri}
415+
self._x_session_token = {"X-Session-Token": self._e2e_uri}
369416

370417
def _load_cert_by_cert_path(self) -> str:
371-
with open(self.cert_path, 'r') as f:
418+
with open(self.cert_path, "r") as f:
372419
cert_pem = f.read()
373420
return cert_pem
374421

375422
def _load_cert_by_ak_sk(self, ep: str) -> str:
376-
get_endpoint_certificate_request = volcenginesdkark.GetEndpointCertificateRequest(
377-
id=ep
423+
get_endpoint_certificate_request = (
424+
volcenginesdkark.GetEndpointCertificateRequest(id=ep)
378425
)
379426
try:
380-
resp: volcenginesdkark.GetEndpointCertificateResponse = self.api_instance.get_endpoint_certificate(
381-
get_endpoint_certificate_request)
427+
resp: volcenginesdkark.GetEndpointCertificateResponse = (
428+
self.api_instance.get_endpoint_certificate(
429+
get_endpoint_certificate_request
430+
)
431+
)
382432
except ApiException as e:
383-
raise ArkAPIError("Getting model vendor encryption certificate failed: %s\n" % e)
433+
raise ArkAPIError(
434+
"Getting model vendor encryption certificate failed: %s\n" % e
435+
)
384436

385437
return resp.pca_instance_certificate
386438

387439
def _sync_load_cert_by_auth(self, ep: str) -> str:
388440
try: # try to make request with session header (used for header statistic)
389-
resp = self.client.post(self._e2e_uri, options={"headers": self._x_session_token},
390-
body={"model": ep}, cast_to=self.CertificateResponse)
441+
resp = self.client.post(
442+
self._e2e_uri,
443+
options={"headers": self._x_session_token},
444+
body={"model": ep},
445+
cast_to=self.CertificateResponse,
446+
)
391447
except Exception as e:
392448
raise ArkAPIError("Getting Certificate failed: %s\n" % e)
393-
if 'error' in resp:
394-
raise ArkAPIError("Getting Certificate failed: %s\n" % resp['error'])
395-
return resp['Certificate']
449+
if "error" in resp:
450+
raise ArkAPIError("Getting Certificate failed: %s\n" % resp["error"])
451+
return resp["Certificate"]
396452

397453
def _save_cert_to_file(self, ep: str, cert_pem: str):
398454
cert_file_path = os.path.join(self._cert_storage_path, f"{ep}.pem")
399-
with open(cert_file_path, 'w') as f:
455+
with open(cert_file_path, "w") as f:
400456
f.write(cert_pem)
401457

402458
def _load_cert_locally(self, ep: str) -> str | None:
@@ -406,7 +462,7 @@ def _load_cert_locally(self, ep: str) -> str | None:
406462
current_time = time.time()
407463
time_difference = current_time - last_modified_time
408464
if time_difference <= self._cert_expiration_seconds:
409-
with open(cert_file_path, 'r') as f:
465+
with open(cert_file_path, "r") as f:
410466
return f.read()
411467
else:
412468
os.remove(cert_file_path)

volcenginesdkarkruntime/_compat.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,7 @@
2424
def parse_date(value: date | StrBytesIntFloat) -> date: # noqa: ARG001
2525
...
2626

27-
def parse_datetime(
28-
value: Union[datetime, StrBytesIntFloat]
29-
) -> datetime: # noqa: ARG001
27+
def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: # noqa: ARG001
3028
...
3129

3230
def get_args(t: type[Any]) -> tuple[Any, ...]: # noqa: ARG001
@@ -87,9 +85,7 @@ def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
8785
if PYDANTIC_V2:
8886
return model.model_validate(value)
8987
else:
90-
return cast(
91-
_ModelT, model.parse_obj(value)
92-
) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
88+
return cast(_ModelT, model.parse_obj(value)) # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
9389

9490

9591
def field_is_required(field: FieldInfo) -> bool:

volcenginesdkarkruntime/_constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
DEFAULT_TIMEOUT_SECONDS = 600.0
1414
DEFAULT_CONNECT_TIMEOUT_SECONDS = 60.0
1515
# default timeout is 1 minutes
16-
DEFAULT_TIMEOUT = httpx.Timeout(timeout=DEFAULT_TIMEOUT_SECONDS, connect=DEFAULT_CONNECT_TIMEOUT_SECONDS)
16+
DEFAULT_TIMEOUT = httpx.Timeout(
17+
timeout=DEFAULT_TIMEOUT_SECONDS, connect=DEFAULT_CONNECT_TIMEOUT_SECONDS
18+
)
1719

1820
DEFAULT_MAX_RETRIES = 2
1921
DEFAULT_CONNECTION_LIMITS = httpx.Limits(

volcenginesdkarkruntime/_exceptions.py

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, Optional, cast, Dict
5+
from typing import Optional
66
from typing_extensions import Literal
77

88
import httpx
@@ -130,45 +130,31 @@ def __init__(self, request: httpx.Request, request_id: str) -> None:
130130

131131

132132
class ArkBadRequestError(ArkAPIStatusError):
133-
status_code: Literal[400] = (
134-
400 # pyright: ignore[reportIncompatibleVariableOverride]
135-
)
133+
status_code: Literal[400] = 400 # pyright: ignore[reportIncompatibleVariableOverride]
136134

137135

138136
class ArkAuthenticationError(ArkAPIStatusError):
139-
status_code: Literal[401] = (
140-
401 # pyright: ignore[reportIncompatibleVariableOverride]
141-
)
137+
status_code: Literal[401] = 401 # pyright: ignore[reportIncompatibleVariableOverride]
142138

143139

144140
class ArkPermissionDeniedError(ArkAPIStatusError):
145-
status_code: Literal[403] = (
146-
403 # pyright: ignore[reportIncompatibleVariableOverride]
147-
)
141+
status_code: Literal[403] = 403 # pyright: ignore[reportIncompatibleVariableOverride]
148142

149143

150144
class ArkNotFoundError(ArkAPIStatusError):
151-
status_code: Literal[404] = (
152-
404 # pyright: ignore[reportIncompatibleVariableOverride]
153-
)
145+
status_code: Literal[404] = 404 # pyright: ignore[reportIncompatibleVariableOverride]
154146

155147

156148
class ArkConflictError(ArkAPIStatusError):
157-
status_code: Literal[409] = (
158-
409 # pyright: ignore[reportIncompatibleVariableOverride]
159-
)
149+
status_code: Literal[409] = 409 # pyright: ignore[reportIncompatibleVariableOverride]
160150

161151

162152
class ArkUnprocessableEntityError(ArkAPIStatusError):
163-
status_code: Literal[422] = (
164-
422 # pyright: ignore[reportIncompatibleVariableOverride]
165-
)
153+
status_code: Literal[422] = 422 # pyright: ignore[reportIncompatibleVariableOverride]
166154

167155

168156
class ArkRateLimitError(ArkAPIStatusError):
169-
status_code: Literal[429] = (
170-
429 # pyright: ignore[reportIncompatibleVariableOverride]
171-
)
157+
status_code: Literal[429] = 429 # pyright: ignore[reportIncompatibleVariableOverride]
172158

173159

174160
class ArkInternalServerError(ArkAPIStatusError):

volcenginesdkarkruntime/_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def to_json(
143143
@override
144144
def __str__(self) -> str:
145145
# mypy complains about an invalid self arg
146-
return f'{self.__repr_name__()}({self.__repr_str__(", ")})' # type: ignore[misc]
146+
return f"{self.__repr_name__()}({self.__repr_str__(', ')})" # type: ignore[misc]
147147

148148
# Override the 'construct' method in a way that supports recursive parsing without validation.
149149
# Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.

volcenginesdkarkruntime/_request_options.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@ def construct( # type: ignore
5959
}
6060
if PYDANTIC_V2:
6161
return super().model_construct(_fields_set, **kwargs)
62-
return cast(
63-
RequestOptions, super().construct(_fields_set, **kwargs)
64-
) # pyright: ignore[reportDeprecated]
62+
return cast(RequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
6563

6664
if not TYPE_CHECKING:
6765
# type checkers incorrectly complain about this assignment

0 commit comments

Comments
 (0)