11from __future__ import annotations
22
3- import warnings
4- from typing import Callable , Dict , Iterable , List , Optional , Union
3+ from typing import Dict , Iterable , List , Optional , Union
54
65import httpx
76from typing_extensions import Literal
87
98from ...._base_client import make_request_options
109from ...._compat import cached_property
11- from ...._constants import ARK_E2E_ENCRYPTION_HEADER
1210from ...._resource import AsyncAPIResource , SyncAPIResource
1311from ...._types import Body , Headers , Query
14- from ...._utils import async_with_sts_token , deepcopy_minimal , with_sts_token
15- from .... _utils . _key_agreement import aes_gcm_decrypt_base64_string
12+ from ...._utils import async_with_sts_token , with_sts_token
13+ from ...encryption import with_e2e_encryption , async_with_e2e_encryption
1614from ....types .chat import (
1715 ChatCompletion ,
1816 ChatCompletionMessageParam ,
2927__all__ = ["Completions" , "AsyncCompletions" ]
3028
3129
32- def _process_messages (
33- messages : Iterable [ChatCompletionMessageParam ], f : Callable [[str ], str ]
34- ):
35- for message in messages :
36- if message .get ("content" , None ) is not None :
37- current_content = message .get ("content" )
38- if isinstance (current_content , str ):
39- message ["content" ] = f (current_content )
40- elif isinstance (current_content , Iterable ):
41- for part in current_content :
42- if part .get ("type" , None ) == "text" :
43- part ["text" ] = f (part ["text" ])
44- elif part .get ("type" , None ) == "image_url" :
45- if part ["image_url" ]["url" ].startswith ("data:" ):
46- part ["image_url" ]["url" ] = f (part ["image_url" ]["url" ])
47- else :
48- warnings .warn (
49- "encryption is not supported for image url, "
50- "please use base64 image if you want encryption"
51- )
52- else :
53- raise TypeError (
54- "encryption is not supported for content type {}" .format (
55- type (part )
56- )
57- )
58- else :
59- raise TypeError (
60- "encryption is not supported for content type {}" .format (
61- type (message .get ("content" ))
62- )
63- )
64-
65-
6630class Completions (SyncAPIResource ):
6731 @cached_property
6832 def with_raw_response (self ) -> CompletionsWithRawResponse :
6933 return CompletionsWithRawResponse (self )
7034
71- def _process_messages (
72- self , messages : Iterable [ChatCompletionMessageParam ], f : Callable [[str ], str ]
73- ):
74- for message in messages :
75- if message .get ("content" , None ) is not None :
76- current_content = message .get ("content" )
77- if isinstance (current_content , str ):
78- message ["content" ] = f (current_content )
79- elif isinstance (current_content , Iterable ):
80- raise TypeError (
81- "content type {} is not supported end-to-end encryption" .format (
82- type (message .get ("content" ))
83- )
84- )
85- else :
86- raise TypeError (
87- "content type {} is not supported end-to-end encryption" .format (
88- type (message .get ("content" ))
89- )
90- )
91-
92- def _encrypt (
93- self ,
94- model : str ,
95- messages : Iterable [ChatCompletionMessageParam ],
96- extra_headers : Headers ,
97- ) -> tuple [bytes , bytes ]:
98- client = self ._client ._get_endpoint_certificate (model )
99- _crypto_key , _crypto_nonce , session_token = client .generate_ecies_key_pair ()
100- extra_headers ["X-Session-Token" ] = session_token
101- _process_messages (
102- messages ,
103- lambda x : client .encrypt_string_with_key (_crypto_key , _crypto_nonce , x ),
104- )
105- return _crypto_key , _crypto_nonce
106-
107- def _decrypt (
108- self , key : bytes , nonce : bytes , resp : ChatCompletion
109- ) -> ChatCompletion :
110- if resp .choices is not None :
111- for index , choice in enumerate (resp .choices ):
112- if (
113- choice .message is not None
114- and choice .finish_reason != "content_filter"
115- and choice .message .content is not None
116- ):
117- choice .message .content = aes_gcm_decrypt_base64_string (
118- key , nonce , choice .message .content
119- )
120- resp .choices [index ] = choice
121- return resp
122-
12335 @with_sts_token
36+ @with_e2e_encryption
12437 def create (
12538 self ,
12639 * ,
@@ -151,15 +64,6 @@ def create(
15164 extra_body : Body | None = None ,
15265 timeout : float | httpx .Timeout | None = None ,
15366 ) -> ChatCompletion :
154- is_encrypt = False
155- if (
156- extra_headers is not None
157- and extra_headers .get (ARK_E2E_ENCRYPTION_HEADER , None ) == "true"
158- ):
159- is_encrypt = True
160- messages = deepcopy_minimal (messages )
161- e2e_key , e2e_nonce = self ._encrypt (model , messages , extra_headers )
162-
16367 deadline = get_request_last_time (self ._client , timeout )
16468 breaker = self ._client .get_model_breaker (model )
16569
@@ -200,10 +104,6 @@ def create(
200104 ),
201105 cast_to = ChatCompletion ,
202106 )
203-
204- if is_encrypt :
205- resp = self ._decrypt (e2e_key , e2e_nonce , resp )
206- return resp
207107 return resp
208108
209109
@@ -212,38 +112,8 @@ class AsyncCompletions(AsyncAPIResource):
212112 def with_raw_response (self ) -> AsyncCompletionsWithRawResponse :
213113 return AsyncCompletionsWithRawResponse (self )
214114
215- def _encrypt (
216- self ,
217- model : str ,
218- messages : Iterable [ChatCompletionMessageParam ],
219- extra_headers : Headers ,
220- ) -> tuple [bytes , bytes ]:
221- client = self ._client ._get_endpoint_certificate (model )
222- _crypto_key , _crypto_nonce , session_token = client .generate_ecies_key_pair ()
223- extra_headers ["X-Session-Token" ] = session_token
224- _process_messages (
225- messages ,
226- lambda x : client .encrypt_string_with_key (_crypto_key , _crypto_nonce , x ),
227- )
228- return _crypto_key , _crypto_nonce
229-
230- async def _decrypt (
231- self , key : bytes , nonce : bytes , resp : ChatCompletion
232- ) -> ChatCompletion :
233- if resp .choices is not None :
234- for index , choice in enumerate (resp .choices ):
235- if (
236- choice .message is not None
237- and choice .finish_reason != "content_filter"
238- and choice .message .content is not None
239- ):
240- choice .message .content = aes_gcm_decrypt_base64_string (
241- key , nonce , choice .message .content
242- )
243- resp .choices [index ] = choice
244- return resp
245-
246115 @async_with_sts_token
116+ @async_with_e2e_encryption
247117 async def create (
248118 self ,
249119 * ,
@@ -274,15 +144,6 @@ async def create(
274144 extra_body : Body | None = None ,
275145 timeout : float | httpx .Timeout | None = None ,
276146 ) -> ChatCompletion :
277- is_encrypt = False
278- if (
279- extra_headers is not None
280- and extra_headers .get (ARK_E2E_ENCRYPTION_HEADER , None ) == "true"
281- ):
282- is_encrypt = True
283- messages = deepcopy_minimal (messages )
284- e2e_key , e2e_nonce = self ._encrypt (model , messages , extra_headers )
285-
286147 deadline = get_request_last_time (self ._client , timeout )
287148 breaker = await self ._client .get_model_breaker (model )
288149
@@ -324,6 +185,4 @@ async def create(
324185 cast_to = ChatCompletion ,
325186 )
326187
327- if is_encrypt :
328- resp = await self ._decrypt (e2e_key , e2e_nonce , resp )
329188 return resp
0 commit comments