Skip to content

Commit 9fab4b6

Browse files
authored
Add Glossary support
- Introduced Glossary, GlossaryImport, and GlossaryCounts models, plus Glossaries service for CRUD, import/export, and counts - Extended Documents, high-level client, and LaraClient to support glossaries parameter and binary CSV export - Updated LaraClient._request and HTTP methods to return raw bytes for CSV/export endpoints
1 parent f833ae3 commit 9fab4b6

2 files changed

Lines changed: 109 additions & 11 deletions

File tree

src/lara_sdk/_client.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self, access_key_id: str, access_key_secret: str, server_url: str =
7575
self.sdk_name: str = 'lara-python'
7676
self.sdk_version: str = __import__('lara_sdk').__version__
7777

78-
def get(self, path: str, params: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List]]:
78+
def get(self, path: str, params: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List, bytes]]:
7979
"""
8080
Sends a GET request to the Lara API.
8181
:param path: The path to send the request to.
@@ -85,7 +85,7 @@ def get(self, path: str, params: Dict = None, headers: Dict = None) -> Optional[
8585
"""
8686
return self._request('GET', path, body=params, headers=headers)
8787

88-
def delete(self, path: str, params: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List]]:
88+
def delete(self, path: str, params: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List, bytes]]:
8989
"""
9090
Sends a DELETE request to the Lara API.
9191
:param path: The path to send the request to.
@@ -95,7 +95,7 @@ def delete(self, path: str, params: Dict = None, headers: Dict = None) -> Option
9595
"""
9696
return self._request('DELETE', path, body=params, headers=headers)
9797

98-
def post(self, path: str, body: Dict = None, files: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List]]:
98+
def post(self, path: str, body: Dict = None, files: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List, bytes]]:
9999
"""
100100
Sends a POST request to the Lara API.
101101
:param path: The path to send the request to.
@@ -106,7 +106,7 @@ def post(self, path: str, body: Dict = None, files: Dict = None, headers: Dict =
106106
"""
107107
return self._request('POST', path, body, files, headers)
108108

109-
def put(self, path: str, body: Dict = None, files: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List]]:
109+
def put(self, path: str, body: Dict = None, files: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List, bytes]]:
110110
"""
111111
Sends a PUT request to the Lara API.
112112
:param path: The path to send the request to.
@@ -117,7 +117,7 @@ def put(self, path: str, body: Dict = None, files: Dict = None, headers: Dict =
117117
"""
118118
return self._request('PUT', path, body, files, headers)
119119

120-
def _request(self, method: str, path: str, body: Dict = None, files: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List]]:
120+
def _request(self, method: str, path: str, body: Dict = None, files: Dict = None, headers: Dict = None) -> Optional[Union[Dict, List, bytes]]:
121121
if not path.startswith('/'):
122122
path = '/' + path
123123

@@ -145,5 +145,7 @@ def _request(self, method: str, path: str, body: Dict = None, files: Dict = None
145145
response = self.session.request('POST', f'{self.base_url}{path}', headers=_headers, json=body)
146146

147147
if 200 <= response.status_code < 300:
148+
if "text/csv" in response.headers.get('Content-Type', ''):
149+
return response.content
148150
return response.json().get('content', None)
149151
raise LaraApiError.from_response(response)

src/lara_sdk/_translator.py

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import time
22
from datetime import datetime
33
from enum import Enum
4-
from typing import Optional, Union, List, Iterable, Callable
4+
from typing import Optional, Union, List, Iterable, Callable, Literal
55
from dataclasses import dataclass
66

77
from gzip_stream import GZIPCompressedStream
@@ -36,9 +36,32 @@ def __init__(self, **kwargs):
3636
self.size: int = kwargs.get('size')
3737
self.progress: float = kwargs.get('progress')
3838

39+
class Glossary(LaraObject):
40+
def __init__(self, **kwargs):
41+
self.id: str = kwargs.get('id')
42+
self.name: str = kwargs.get('name')
43+
self.owner_id: str = kwargs.get('owner_id')
44+
self.created_at: datetime = self._parse_date(kwargs.get('created_at', None))
45+
self.updated_at: datetime = self._parse_date(kwargs.get('updated_at', None))
46+
47+
class GlossaryImport(LaraObject):
48+
def __init__(self, **kwargs):
49+
self.id: str = kwargs.get('id')
50+
self.begin: int = kwargs.get('begin')
51+
self.end: int = kwargs.get('end')
52+
self.channel: int = kwargs.get('channel')
53+
self.size: int = kwargs.get('size')
54+
self.progress: float = kwargs.get('progress')
55+
56+
class GlossaryCounts(LaraObject):
57+
def __init__(self, **kwargs):
58+
self.unidirectional: Optional[dict[str, int]] = kwargs.get('unidirectional')
59+
self.multidirectional: Optional[int] = kwargs.get('multidirectional')
60+
3961
@dataclass
4062
class DocumentOptions:
4163
adapt_to: Optional[List[str]] = None
64+
glossaries: Optional[List[str]] = None
4265
no_trace: Optional[bool] = None
4366

4467
class Document(LaraObject):
@@ -168,6 +191,75 @@ def wait_for_import(self, memory_import: MemoryImport, *,
168191

169192
return memory_import
170193

194+
class Glossaries:
195+
def __init__(self, client: LaraClient):
196+
self._client: LaraClient = client
197+
self._polling_interval: int = 2
198+
199+
def list(self) -> List[Glossary]:
200+
return [Glossary(**e) for e in self._client.get('/glossaries')]
201+
202+
def create(self, name: str) -> Glossary:
203+
return Glossary(**self._client.post('/glossaries', {
204+
'name': name
205+
}))
206+
207+
def get(self, id_: str) -> Optional[Glossary]:
208+
try:
209+
return Glossary(**self._client.get(f'/glossaries/{id_}'))
210+
except LaraApiError as e:
211+
if e.status_code == 404:
212+
return None
213+
raise
214+
215+
def delete(self, id_: str) -> Glossary:
216+
return Glossary(**self._client.delete(f'/glossaries/{id_}'))
217+
218+
def update(self, id_: str, name: str) -> Glossary:
219+
return Glossary(**self._client.put(f'/glossaries/{id_}', {
220+
'name': name
221+
}))
222+
223+
def import_csv(self, id_: str, csv: str) -> GlossaryImport:
224+
with open(csv, 'rb') as stream:
225+
compressed_stream = GZIPCompressedStream(stream, compression_level=7)
226+
return GlossaryImport(**self._client.post(f'/glossaries/{id_}/import',
227+
{'compression': 'gzip'}, {'csv': compressed_stream}))
228+
229+
def get_import_status(self, id_: str) -> GlossaryImport:
230+
return GlossaryImport(**self._client.get(f'/glossaries/imports/{id_}'))
231+
232+
def wait_for_import(self, glossary_import: GlossaryImport, *,
233+
update_callback: Callable[[GlossaryImport], None] = None,
234+
max_wait_time: float = 0) -> GlossaryImport:
235+
start = time.time()
236+
while glossary_import.progress < 1.:
237+
if 0 < max_wait_time < time.time() - start:
238+
raise TimeoutError()
239+
240+
time.sleep(self._polling_interval)
241+
242+
glossary_import = self.get_import_status(glossary_import.id)
243+
if update_callback is not None:
244+
update_callback(glossary_import)
245+
246+
return glossary_import
247+
248+
def counts(self, id_: str) -> GlossaryCounts:
249+
return GlossaryCounts(**self._client.get(f'/glossaries/{id_}/counts'))
250+
251+
252+
def export(self, id_: str, content_type: Literal["csv/table-uni"], source: Optional[str]) -> bytes:
253+
"""
254+
Exports a csv file with the glossary content. If the content_type is "csv/table-uni", the
255+
file will contain a unidirectional glossary with only terms in the specified source language (required)
256+
"""
257+
response = self._client.get(f'/glossaries/{id_}/export', {
258+
'content_type': content_type,
259+
'source': source
260+
})
261+
return response
262+
171263

172264

173265
class DocumentStatus(Enum):
@@ -186,7 +278,7 @@ def __init__(self, client: LaraClient):
186278
self._polling_interval: int = 2
187279

188280
def upload(self, file_path: str, filename: str, target: str, source: Optional[str] = None,
189-
adapt_to: Optional[List[str]] = None, no_trace: bool = False) -> Document:
281+
adapt_to: Optional[List[str]] = None, glossaries: Optional[List[str]] = None, no_trace: bool = False) -> Document:
190282
with open(file_path, 'rb') as file_payload:
191283
response_data = self._client.get('/documents/upload-url', {'filename': filename})
192284

@@ -205,6 +297,9 @@ def upload(self, file_path: str, filename: str, target: str, source: Optional[st
205297
if adapt_to is not None:
206298
body['adapt_to'] = adapt_to
207299

300+
if glossaries is not None:
301+
body['glossaries'] = glossaries
302+
208303
headers = None
209304
if no_trace is True:
210305
headers = {'X-No-Trace': 'true'}
@@ -222,11 +317,11 @@ def download(self, id: str, output_format: Optional[str] = None) -> bytes:
222317
return self._s3client.download(url=url)
223318

224319
def translate(self, file_path: str, filename: str, target: str, source: Optional[str] = None,
225-
adapt_to: Optional[List[str]] = None, output_format: Optional[str] = None,
320+
adapt_to: Optional[List[str]] = None, glossaries: Optional[List[str]] = None, output_format: Optional[str] = None,
226321
no_trace: bool = False) -> bytes:
227322

228323
document = self.upload(file_path=file_path, filename=filename, target=target, source=source, adapt_to=adapt_to,
229-
no_trace=no_trace)
324+
glossaries=glossaries, no_trace=no_trace)
230325

231326
max_wait_time = 60 * 15 # 15 minutes
232327
start = time.time()
@@ -265,13 +360,14 @@ def __init__(self, credentials: Credentials = None, *,
265360
self._client: LaraClient = LaraClient(credentials.access_key_id, credentials.access_key_secret, server_url)
266361
self.memories: Memories = Memories(self._client)
267362
self.documents: Documents = Documents(self._client)
363+
self.glossaries: Glossaries = Glossaries(self._client)
268364

269365
def languages(self) -> List[str]:
270366
return self._client.get('/languages')
271367

272368
def translate(self, text: Union[str, Iterable[str], Iterable[TextBlock]], *,
273369
source: str = None, source_hint: str = None, target: str, adapt_to: List[str] = None,
274-
instructions: List[str] = None, content_type: str = None,
370+
glossaries: List[str] = None, instructions: List[str] = None, content_type: str = None,
275371
multiline: bool = True, timeout_ms: int = None, priority: TranslatePriority = None,
276372
use_cache: Union[bool, UseCache] = None, cache_ttl_s: int = None,
277373
no_trace: bool = False, verbose: bool = False) -> TextResult:
@@ -297,7 +393,7 @@ def translate(self, text: Union[str, Iterable[str], Iterable[TextBlock]], *,
297393
'multiline': multiline, 'adapt_to': adapt_to, 'instructions': instructions, 'timeout': timeout_ms, 'q': q,
298394
'priority': priority.value if priority is not None else None,
299395
'use_cache': use_cache.value if use_cache is not None else None, 'cache_ttl': cache_ttl_s,
300-
'verbose': verbose
396+
'glossaries': glossaries, 'verbose': verbose
301397
}
302398

303399
headers = None

0 commit comments

Comments
 (0)