Skip to content

Commit 9435f7a

Browse files
authored
Add document translation support
1 parent 9e74b45 commit 9435f7a

3 files changed

Lines changed: 122 additions & 2 deletions

File tree

src/lara_sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from ._client import LaraObject
22
from ._credentials import Credentials
33
from ._errors import LaraApiError, LaraError
4-
from ._translator import Memory, MemoryImport, TextBlock, TextResult, Memories, Translator, TranslatePriority, UseCache
4+
from ._translator import Memory, MemoryImport, TextBlock, TextResult, Memories, Translator, TranslatePriority, UseCache, Documents, Document, DocumentStatus
55

66
# This constant is auto-generated by the build script.
77
# Manual modifications will be overwritten and may cause unexpected behavior.

src/lara_sdk/_s3client.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import IO, TypedDict
2+
import requests
3+
4+
class S3UploadFields(TypedDict):
    """Form fields sent together with the file when uploading to S3.

    The values are obtained from the upload-url response and forwarded
    unchanged as multipart form data by ``S3Client.upload``.
    """

    # Field names mirror the keys of the ``fields`` object in the response.
    acl: str
    bucket: str
    key: str
8+
9+
10+
class S3Client:
    """Minimal HTTP client used to upload and download files via pre-signed URLs.

    A single ``requests.Session`` is reused for every call made through the
    same instance.
    """

    def __init__(self):
        self._session = requests.Session()

    def upload(self, url: str, fields: S3UploadFields, file_payload: IO[bytes]) -> None:
        """POST ``file_payload`` to ``url`` as multipart form data.

        :param url: destination URL of the upload.
        :param fields: form fields sent alongside the file; every value is
            converted to ``str`` before sending.
        :param file_payload: binary file object, sent under the ``file`` form key.
        :raises requests.RequestException: on transport failure or when the
            server answers with an HTTP error status.
        """
        form_data = {key: str(value) for key, value in fields.items()}

        # Errors propagate unchanged: the previous catch-and-re-raise added nothing.
        response = self._session.post(url, data=form_data, files={'file': file_payload})
        response.raise_for_status()

    def download(self, url: str) -> bytes:
        """GET ``url`` and return the raw response body.

        :param url: URL of the file to download.
        :return: the response content as bytes.
        :raises requests.RequestException: on transport failure or when the
            server answers with an HTTP error status.
        """
        response = self._session.get(url)
        response.raise_for_status()
        return response.content
33+

src/lara_sdk/_translator.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from datetime import datetime
33
from enum import Enum
44
from typing import Optional, Union, List, Iterable, Callable
5+
from dataclasses import dataclass
56

67
from gzip_stream import GZIPCompressedStream
78

89
from ._client import LaraObject, LaraClient
910
from ._credentials import Credentials
1011
from ._errors import LaraApiError
11-
12+
from ._s3client import S3Client, S3UploadFields
1213

1314
# Objects --------------------------------------------------------------------------------------------------------------
1415

@@ -35,6 +36,25 @@ def __init__(self, **kwargs):
3536
self.size: int = kwargs.get('size')
3637
self.progress: float = kwargs.get('progress')
3738

39+
@dataclass
class DocumentOptions:
    """Optional settings attached to a document translation."""

    # Values passed as ``adapt_to`` when the document was created; None when not set.
    adapt_to: Optional[List[str]] = None
42+
43+
class Document(LaraObject):
    """A document submitted for translation, built from an API response payload."""

    def __init__(self, **kwargs):
        """Populate the document from the keyword arguments of an API payload.

        :raises ValueError: if ``status`` is not a valid ``DocumentStatus`` value.
        """
        self.id: str = kwargs.get('id')
        self.status: DocumentStatus = DocumentStatus(kwargs.get('status'))
        self.source: Optional[str] = kwargs.get('source')
        self.target: str = kwargs.get('target')
        self.filename: str = kwargs.get('filename')
        self.created_at: datetime = self._parse_date(kwargs.get('created_at'))
        self.updated_at: datetime = self._parse_date(kwargs.get('updated_at'))

        options = kwargs.get('options')
        self.options: Optional[DocumentOptions] = DocumentOptions(**options) if options else None

        # A count of 0 is a real value and must not be collapsed into None.
        self.translated_chars: Optional[int] = self._optional_int(kwargs.get('translated_chars'))
        self.total_chars: Optional[int] = self._optional_int(kwargs.get('total_chars'))
        self.error_reason: Optional[str] = kwargs.get('error_reason')

    @staticmethod
    def _optional_int(value) -> Optional[int]:
        """Return ``int(value)``, or None when the value is missing (None or '')."""
        if value is None or value == '':
            return None
        return int(value)
57+
3858

3959
class TextBlock(LaraObject):
4060
def __init__(self, **kwargs):
@@ -148,6 +168,72 @@ def wait_for_import(self, memory_import: MemoryImport, *,
148168
return memory_import
149169

150170

171+
172+
class DocumentStatus(Enum):
    """Lifecycle states of a document translation."""

    # The document has just been created.
    INITIALIZED = 'initialized'
    # The document is being analyzed for language detection and chars count.
    ANALYZING = 'analyzing'
    # Paused after analysis; the user has to confirm before it continues.
    PAUSED = 'paused'
    # The document is ready to be translated.
    READY = 'ready'
    TRANSLATING = 'translating'
    TRANSLATED = 'translated'
    ERROR = 'error'
180+
181+
class Documents:
    """Document translation operations: upload, status lookup, download and translate."""

    def __init__(self, client: LaraClient):
        self._client: LaraClient = client
        self._s3client = S3Client()
        # Seconds to sleep between two status checks in translate().
        self._polling_interval: int = 2

    def upload(self, file_path: str, filename: str, target: str, source: Optional[str] = None,
               adapt_to: Optional[List[str]] = None) -> Document:
        """Upload a local file and create the corresponding document.

        :param file_path: path of the local file to upload.
        :param filename: file name sent to the API when requesting the upload URL.
        :param target: target language of the translation.
        :param source: source language; omitted from the request when None.
        :param adapt_to: sent as ``adapt_to`` in the request; omitted when None.
        :return: the document created by the API.
        """
        # The file is opened first so that an unreadable path fails before any request.
        with open(file_path, 'rb') as file_payload:
            response_data = self._client.get('/documents/upload-url', {'filename': filename})

            url: str = response_data['url']
            fields: S3UploadFields = S3UploadFields(**response_data['fields'])

            self._s3client.upload(url, fields, file_payload)

        body = {
            's3key': fields['key'],
            'target': target,
        }
        if source is not None:
            body['source'] = source
        if adapt_to is not None:
            body['adapt_to'] = adapt_to

        return Document(**self._client.post('/documents', body))

    def status(self, id: str) -> Document:
        """Fetch the current state of the document with the given id."""
        return Document(**self._client.get(f'/documents/{id}'))

    def download(self, id: str, output_format: Optional[str] = None) -> bytes:
        """Download the content of the document with the given id.

        :param id: identifier of the document.
        :param output_format: sent as ``output_format`` when requesting the
            download URL; omitted when None.
        :return: the downloaded file content.
        """
        params = {}
        if output_format is not None:
            params['output_format'] = output_format

        url: str = self._client.get(f'/documents/{id}/download-url', params)['url']
        return self._s3client.download(url=url)

    def translate(self, file_path: str, filename: str, target: str, source: Optional[str] = None,
                  adapt_to: Optional[List[str]] = None, output_format: Optional[str] = None) -> bytes:
        """Upload a file, wait for its translation and return the translated content.

        The document status is polled every ``self._polling_interval`` seconds
        for at most 15 minutes.

        :return: the translated file content, as returned by ``download``.
        :raises LaraApiError: if the document ends up in the ERROR status.
        :raises TimeoutError: if the translation is not completed in time.
        """
        document = self.upload(file_path=file_path, filename=filename, target=target, source=source,
                               adapt_to=adapt_to)

        max_wait_time = 60 * 15  # 15 minutes
        # Monotonic clock: the deadline must not move if the system clock is adjusted.
        start = time.monotonic()

        while time.monotonic() - start < max_wait_time:
            document = self.status(id=document.id)

            # Document.status is already a DocumentStatus member.
            if document.status is DocumentStatus.TRANSLATED:
                return self.download(id=document.id, output_format=output_format)
            if document.status is DocumentStatus.ERROR:
                raise LaraApiError(500, "DocumentError", document.error_reason)

            time.sleep(self._polling_interval)

        raise TimeoutError(f'Document {document.id} was not translated within {max_wait_time} seconds')
236+
151237
class TranslatePriority(Enum):
152238
NORMAL = 'normal'
153239
BACKGROUND = 'background'
@@ -170,6 +256,7 @@ def __init__(self, credentials: Credentials = None, *,
170256

171257
self._client: LaraClient = LaraClient(credentials.access_key_id, credentials.access_key_secret, server_url)
172258
self.memories: Memories = Memories(self._client)
259+
self.documents: Documents = Documents(self._client)
173260

174261
def languages(self) -> List[str]:
175262
return self._client.get('/languages')

0 commit comments

Comments
 (0)