Skip to content

Commit 17c45a0

Browse files
author
liuhuiqi.7
committed
feat(crypto): fit batch encryption
1 parent e1aa308 commit 17c45a0

1 file changed

Lines changed: 5 additions & 146 deletions

File tree

Lines changed: 5 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
from __future__ import annotations
22

3-
import warnings
4-
from typing import Callable, Dict, Iterable, List, Optional, Union
3+
from typing import Dict, Iterable, List, Optional, Union
54

65
import httpx
76
from typing_extensions import Literal
87

98
from ...._base_client import make_request_options
109
from ...._compat import cached_property
11-
from ...._constants import ARK_E2E_ENCRYPTION_HEADER
1210
from ...._resource import AsyncAPIResource, SyncAPIResource
1311
from ...._types import Body, Headers, Query
14-
from ...._utils import async_with_sts_token, deepcopy_minimal, with_sts_token
15-
from ...._utils._key_agreement import aes_gcm_decrypt_base64_string
12+
from ...._utils import async_with_sts_token, with_sts_token
13+
from ...encryption import with_e2e_encryption, async_with_e2e_encryption
1614
from ....types.chat import (
1715
ChatCompletion,
1816
ChatCompletionMessageParam,
@@ -29,98 +27,13 @@
2927
__all__ = ["Completions", "AsyncCompletions"]
3028

3129

32-
def _process_messages(
33-
messages: Iterable[ChatCompletionMessageParam], f: Callable[[str], str]
34-
):
35-
for message in messages:
36-
if message.get("content", None) is not None:
37-
current_content = message.get("content")
38-
if isinstance(current_content, str):
39-
message["content"] = f(current_content)
40-
elif isinstance(current_content, Iterable):
41-
for part in current_content:
42-
if part.get("type", None) == "text":
43-
part["text"] = f(part["text"])
44-
elif part.get("type", None) == "image_url":
45-
if part["image_url"]["url"].startswith("data:"):
46-
part["image_url"]["url"] = f(part["image_url"]["url"])
47-
else:
48-
warnings.warn(
49-
"encryption is not supported for image url, "
50-
"please use base64 image if you want encryption"
51-
)
52-
else:
53-
raise TypeError(
54-
"encryption is not supported for content type {}".format(
55-
type(part)
56-
)
57-
)
58-
else:
59-
raise TypeError(
60-
"encryption is not supported for content type {}".format(
61-
type(message.get("content"))
62-
)
63-
)
64-
65-
6630
class Completions(SyncAPIResource):
6731
@cached_property
6832
def with_raw_response(self) -> CompletionsWithRawResponse:
6933
return CompletionsWithRawResponse(self)
7034

71-
def _process_messages(
72-
self, messages: Iterable[ChatCompletionMessageParam], f: Callable[[str], str]
73-
):
74-
for message in messages:
75-
if message.get("content", None) is not None:
76-
current_content = message.get("content")
77-
if isinstance(current_content, str):
78-
message["content"] = f(current_content)
79-
elif isinstance(current_content, Iterable):
80-
raise TypeError(
81-
"content type {} is not supported end-to-end encryption".format(
82-
type(message.get("content"))
83-
)
84-
)
85-
else:
86-
raise TypeError(
87-
"content type {} is not supported end-to-end encryption".format(
88-
type(message.get("content"))
89-
)
90-
)
91-
92-
def _encrypt(
93-
self,
94-
model: str,
95-
messages: Iterable[ChatCompletionMessageParam],
96-
extra_headers: Headers,
97-
) -> tuple[bytes, bytes]:
98-
client = self._client._get_endpoint_certificate(model)
99-
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
100-
extra_headers["X-Session-Token"] = session_token
101-
_process_messages(
102-
messages,
103-
lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x),
104-
)
105-
return _crypto_key, _crypto_nonce
106-
107-
def _decrypt(
108-
self, key: bytes, nonce: bytes, resp: ChatCompletion
109-
) -> ChatCompletion:
110-
if resp.choices is not None:
111-
for index, choice in enumerate(resp.choices):
112-
if (
113-
choice.message is not None
114-
and choice.finish_reason != "content_filter"
115-
and choice.message.content is not None
116-
):
117-
choice.message.content = aes_gcm_decrypt_base64_string(
118-
key, nonce, choice.message.content
119-
)
120-
resp.choices[index] = choice
121-
return resp
122-
12335
@with_sts_token
36+
@with_e2e_encryption
12437
def create(
12538
self,
12639
*,
@@ -151,15 +64,6 @@ def create(
15164
extra_body: Body | None = None,
15265
timeout: float | httpx.Timeout | None = None,
15366
) -> ChatCompletion:
154-
is_encrypt = False
155-
if (
156-
extra_headers is not None
157-
and extra_headers.get(ARK_E2E_ENCRYPTION_HEADER, None) == "true"
158-
):
159-
is_encrypt = True
160-
messages = deepcopy_minimal(messages)
161-
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
162-
16367
deadline = get_request_last_time(self._client, timeout)
16468
breaker = self._client.get_model_breaker(model)
16569

@@ -200,10 +104,6 @@ def create(
200104
),
201105
cast_to=ChatCompletion,
202106
)
203-
204-
if is_encrypt:
205-
resp = self._decrypt(e2e_key, e2e_nonce, resp)
206-
return resp
207107
return resp
208108

209109

@@ -212,38 +112,8 @@ class AsyncCompletions(AsyncAPIResource):
212112
def with_raw_response(self) -> AsyncCompletionsWithRawResponse:
213113
return AsyncCompletionsWithRawResponse(self)
214114

215-
def _encrypt(
216-
self,
217-
model: str,
218-
messages: Iterable[ChatCompletionMessageParam],
219-
extra_headers: Headers,
220-
) -> tuple[bytes, bytes]:
221-
client = self._client._get_endpoint_certificate(model)
222-
_crypto_key, _crypto_nonce, session_token = client.generate_ecies_key_pair()
223-
extra_headers["X-Session-Token"] = session_token
224-
_process_messages(
225-
messages,
226-
lambda x: client.encrypt_string_with_key(_crypto_key, _crypto_nonce, x),
227-
)
228-
return _crypto_key, _crypto_nonce
229-
230-
async def _decrypt(
231-
self, key: bytes, nonce: bytes, resp: ChatCompletion
232-
) -> ChatCompletion:
233-
if resp.choices is not None:
234-
for index, choice in enumerate(resp.choices):
235-
if (
236-
choice.message is not None
237-
and choice.finish_reason != "content_filter"
238-
and choice.message.content is not None
239-
):
240-
choice.message.content = aes_gcm_decrypt_base64_string(
241-
key, nonce, choice.message.content
242-
)
243-
resp.choices[index] = choice
244-
return resp
245-
246115
@async_with_sts_token
116+
@async_with_e2e_encryption
247117
async def create(
248118
self,
249119
*,
@@ -274,15 +144,6 @@ async def create(
274144
extra_body: Body | None = None,
275145
timeout: float | httpx.Timeout | None = None,
276146
) -> ChatCompletion:
277-
is_encrypt = False
278-
if (
279-
extra_headers is not None
280-
and extra_headers.get(ARK_E2E_ENCRYPTION_HEADER, None) == "true"
281-
):
282-
is_encrypt = True
283-
messages = deepcopy_minimal(messages)
284-
e2e_key, e2e_nonce = self._encrypt(model, messages, extra_headers)
285-
286147
deadline = get_request_last_time(self._client, timeout)
287148
breaker = await self._client.get_model_breaker(model)
288149

@@ -324,6 +185,4 @@ async def create(
324185
cast_to=ChatCompletion,
325186
)
326187

327-
if is_encrypt:
328-
resp = await self._decrypt(e2e_key, e2e_nonce, resp)
329188
return resp

0 commit comments

Comments
 (0)