Skip to content

Commit 9a91a9c

Browse files
Add reasoning support
1 parent e6a73dd commit 9a91a9c

2 files changed

Lines changed: 73 additions & 4 deletions

File tree

src/lara_sdk/_client.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,17 @@ def put(self, path: str, body: Dict = None, files: Dict = None, headers: Dict =
117117
"""
118118
return self._request('PUT', path, body, files, headers)
119119

120+
def post_and_get_stream(self, path: str, body: Dict = None, files: Dict = None, headers: Dict = None):
121+
"""
122+
Sends a POST request to the Lara API and yields streaming responses.
123+
:param path: The path to send the request to.
124+
:param body: The parameters to send with the request.
125+
:param files: The files to send with the request. If present, request will be sent as multipart/form-data.
126+
:param headers: Additional headers to include in the request.
127+
:return: A generator yielding streaming responses.
128+
"""
129+
return self._request_stream('POST', path, body, files, headers)
130+
120131
def _request(self, method: str, path: str, body: Dict = None, files: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List, bytes]]:
121132
if not path.startswith('/'):
122133
path = '/' + path
@@ -148,4 +159,52 @@ def _request(self, method: str, path: str, body: Dict = None, files: Dict = None
148159
if "text/csv" in response.headers.get('Content-Type', ''):
149160
return response.content
150161
return response.json().get('content', None)
151-
raise LaraApiError.from_response(response)
162+
raise LaraApiError.from_response(response)
163+
164+
def _request_stream(self, method: str, path: str, body: Dict = None, files: Dict = None, headers: Dict = None):
165+
if not path.startswith('/'):
166+
path = '/' + path
167+
168+
_headers = {
169+
'X-HTTP-Method-Override': method,
170+
'Date': datetime.datetime.now(datetime.timezone.utc).isoformat(),
171+
'X-Lara-SDK-Name': self.sdk_name,
172+
'X-Lara-SDK-Version': self.sdk_version
173+
}
174+
175+
if headers is not None:
176+
_headers.update(headers)
177+
178+
if body is not None:
179+
body = {k: v for k, v in body.items() if v is not None}
180+
181+
if len(body) > 0:
182+
encoded_body = json.dumps(body, ensure_ascii=False, separators=(',', ':')).encode('UTF-8')
183+
_headers['Content-MD5'] = hashlib.md5(encoded_body).hexdigest()
184+
185+
if files is not None:
186+
response = self.session.request('POST', f'{self.base_url}{path}', headers=_headers, data=body, files=files, stream=True)
187+
else:
188+
response = self.session.request('POST', f'{self.base_url}{path}', headers=_headers, json=body, stream=True)
189+
190+
buffer = ''
191+
for chunk in response.iter_content(chunk_size=None, decode_unicode=True):
192+
if chunk:
193+
buffer += chunk
194+
lines = buffer.split('\n')
195+
buffer = lines.pop()
196+
197+
for line in lines:
198+
if line.strip():
199+
try:
200+
parsed = json.loads(line)
201+
yield parsed.get('data', parsed).get('content')
202+
except (json.JSONDecodeError, AttributeError):
203+
pass
204+
205+
if buffer.strip():
206+
try:
207+
parsed = json.loads(buffer)
208+
yield parsed.get('data', parsed).get('content')
209+
except (json.JSONDecodeError, AttributeError):
210+
pass

src/lara_sdk/_translator.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,8 @@ def translate(self, text: Union[str, Iterable[str], Iterable[TextBlock]], *,
429429
multiline: bool = True, timeout_ms: int = None, priority: TranslatePriority = None,
430430
use_cache: Union[bool, UseCache] = None, cache_ttl_s: int = None,
431431
no_trace: bool = False, verbose: bool = False, style: Optional[TranslationStyle] = None,
432-
headers: Optional[Dict[str, str]] = None) -> TextResult:
432+
headers: Optional[Dict[str, str]] = None, reasoning: bool = False,
433+
callback: Optional[Callable[[TextResult], None]] = None) -> TextResult:
433434
if isinstance(text, str):
434435
q = text
435436
elif hasattr(text, '__iter__'):
@@ -452,7 +453,7 @@ def translate(self, text: Union[str, Iterable[str], Iterable[TextBlock]], *,
452453
'multiline': multiline, 'adapt_to': adapt_to, 'instructions': instructions, 'timeout': timeout_ms, 'q': q,
453454
'priority': priority.value if priority is not None else None,
454455
'use_cache': use_cache.value if use_cache is not None else None, 'cache_ttl': cache_ttl_s,
455-
'glossaries': glossaries, 'verbose': verbose, 'style': style
456+
'glossaries': glossaries, 'verbose': verbose, 'style': style, 'reasoning': reasoning
456457
}
457458

458459
request_headers = {}
@@ -461,7 +462,16 @@ def translate(self, text: Union[str, Iterable[str], Iterable[TextBlock]], *,
461462
if no_trace is True:
462463
request_headers['X-No-Trace'] = 'true'
463464

464-
return TextResult(**self._client.post('/translate', body, headers=request_headers))
465+
last_result = None
466+
for partial in self._client.post_and_get_stream('/translate', body, headers=request_headers):
467+
last_result = partial
468+
if callback is not None and reasoning:
469+
callback(TextResult(**partial))
470+
471+
if last_result is None:
472+
raise ValueError('No translation result received.')
473+
474+
return TextResult(**last_result)
465475

466476
def detect(self, text: Union[str, List[str]], *, hint: Optional[str] = None,
467477
passlist: Optional[List[str]] = None) -> DetectResult:

0 commit comments

Comments
 (0)