11import time
22from datetime import datetime
33from enum import Enum
4- from typing import Optional , Union , List , Iterable , Callable
4+ from typing import Optional , Union , List , Iterable , Callable , Literal
55from dataclasses import dataclass
66
77from gzip_stream import GZIPCompressedStream
@@ -36,9 +36,32 @@ def __init__(self, **kwargs):
3636 self .size : int = kwargs .get ('size' )
3737 self .progress : float = kwargs .get ('progress' )
3838
class Glossary(LaraObject):
    """Glossary record built from keyword arguments.

    Every field is read with ``kwargs.get``, so a key that is absent from
    the payload leaves the corresponding attribute set to ``None``.
    """

    def __init__(self, **kwargs):
        field = kwargs.get

        self.id: str = field('id')
        self.name: str = field('name')
        self.owner_id: str = field('owner_id')

        # Timestamps are not stored raw: each one is handed to
        # self._parse_date, which is not defined in this class
        # (presumably inherited from LaraObject - confirm).
        self.created_at: datetime = self._parse_date(field('created_at'))
        self.updated_at: datetime = self._parse_date(field('updated_at'))
46+
class GlossaryImport(LaraObject):
    """State of a glossary import job, built from keyword arguments.

    Attributes are copied verbatim from the payload; missing keys yield
    ``None``. ``progress`` is the value ``Glossaries.wait_for_import``
    compares against ``1.`` to decide whether to keep polling.
    """

    def __init__(self, **kwargs):
        field = kwargs.get

        self.id: str = field('id')
        self.begin: int = field('begin')
        self.end: int = field('end')
        self.channel: int = field('channel')
        self.size: int = field('size')
        self.progress: float = field('progress')
55+
class GlossaryCounts(LaraObject):
    """Term counts of a glossary, built from keyword arguments.

    Both attributes are optional and stay ``None`` when the payload does
    not contain the corresponding key.
    """

    def __init__(self, **kwargs):
        field = kwargs.get

        # Mapping of str -> int (presumably keyed by language; confirm with the API).
        self.unidirectional: Optional[dict[str, int]] = field('unidirectional')
        self.multidirectional: Optional[int] = field('multidirectional')
60+
@dataclass
class DocumentOptions:
    """Optional settings for a document; every field defaults to ``None``.

    The field order (``adapt_to``, ``glossaries``, ``no_trace``) is also
    the positional order of the generated constructor.
    """

    # List of strings forwarded as the 'adapt_to' option.
    adapt_to: Optional[List[str]] = None
    # List of strings forwarded as the 'glossaries' option.
    glossaries: Optional[List[str]] = None
    # Tri-state flag: None means "not specified".
    no_trace: Optional[bool] = None
4366
4467class Document (LaraObject ):
@@ -168,6 +191,75 @@ def wait_for_import(self, memory_import: MemoryImport, *,
168191
169192 return memory_import
170193
class Glossaries:
    """Wrapper around the ``/glossaries`` endpoints of a ``LaraClient``.

    Every method issues a single request through the client handed to the
    constructor, except ``wait_for_import`` which polls until the import
    reports completion.
    """

    def __init__(self, client: LaraClient):
        self._client: LaraClient = client
        # Seconds slept between two status requests in wait_for_import.
        self._polling_interval: int = 2

    def list(self) -> List[Glossary]:
        """Return one ``Glossary`` per entry of the ``GET /glossaries`` response."""
        return [Glossary(**e) for e in self._client.get('/glossaries')]

    def create(self, name: str) -> Glossary:
        """Create a glossary called *name* and return it."""
        return Glossary(**self._client.post('/glossaries', {
            'name': name
        }))

    def get(self, id_: str) -> Optional[Glossary]:
        """Return the glossary with id *id_*, or ``None`` on a 404.

        Raises:
            LaraApiError: for any API error whose status code is not 404.
        """
        try:
            return Glossary(**self._client.get(f'/glossaries/{id_}'))
        except LaraApiError as e:
            # Only "not found" is translated into None; everything else
            # is propagated unchanged.
            if e.status_code == 404:
                return None
            raise

    def delete(self, id_: str) -> Glossary:
        """Delete the glossary *id_* and return the object sent back by the API."""
        return Glossary(**self._client.delete(f'/glossaries/{id_}'))

    def update(self, id_: str, name: str) -> Glossary:
        """Set the name of glossary *id_* to *name* and return the result."""
        return Glossary(**self._client.put(f'/glossaries/{id_}', {
            'name': name
        }))

    def import_csv(self, id_: str, csv: str) -> GlossaryImport:
        """Upload the CSV file at path *csv* into glossary *id_*.

        The file is read in binary mode and wrapped in a gzip stream
        (compression level 7) before being posted; the request declares
        ``compression: gzip`` accordingly.

        Returns:
            The ``GlossaryImport`` built from the API response.
        """
        with open(csv, 'rb') as stream:
            compressed_stream = GZIPCompressedStream(stream, compression_level=7)
            # The request is sent inside the `with` so the file is still
            # open while the stream is consumed.
            return GlossaryImport(**self._client.post(
                f'/glossaries/{id_}/import',
                {'compression': 'gzip'},
                {'csv': compressed_stream},
            ))

    def get_import_status(self, id_: str) -> GlossaryImport:
        """Fetch the current state of the import job with id *id_*."""
        return GlossaryImport(**self._client.get(f'/glossaries/imports/{id_}'))

    def wait_for_import(self, glossary_import: GlossaryImport, *,
                        update_callback: Optional[Callable[[GlossaryImport], None]] = None,
                        max_wait_time: float = 0) -> GlossaryImport:
        """Poll an import until its ``progress`` reaches 1.

        Args:
            glossary_import: the import to wait for; returned as-is when
                its progress is already >= 1.
            update_callback: if given, called with the refreshed import
                after every status request.
            max_wait_time: maximum number of seconds to wait. Values <= 0
                disable the timeout.

        Returns:
            The most recently fetched ``GlossaryImport``.

        Raises:
            TimeoutError: when *max_wait_time* is positive and has been
                exceeded at the start of a polling iteration.
        """
        start = time.time()
        while glossary_import.progress < 1.:
            # The deadline is evaluated before sleeping, so the total wait
            # may overshoot max_wait_time by up to one polling interval.
            if 0 < max_wait_time < time.time() - start:
                raise TimeoutError()

            time.sleep(self._polling_interval)

            glossary_import = self.get_import_status(glossary_import.id)
            if update_callback is not None:
                update_callback(glossary_import)

        return glossary_import

    def counts(self, id_: str) -> GlossaryCounts:
        """Return the term counts of glossary *id_*."""
        return GlossaryCounts(**self._client.get(f'/glossaries/{id_}/counts'))

    def export(self, id_: str, content_type: Literal["csv/table-uni"], source: Optional[str]) -> bytes:
        """
        Exports a csv file with the glossary content. If the content_type is "csv/table-uni", the
        file will contain a unidirectional glossary with only terms in the specified source language (required)

        The value returned by the client is passed back to the caller unchanged.
        """
        response = self._client.get(f'/glossaries/{id_}/export', {
            'content_type': content_type,
            'source': source
        })
        return response
262+
171263
172264
173265class DocumentStatus (Enum ):
@@ -186,7 +278,7 @@ def __init__(self, client: LaraClient):
186278 self ._polling_interval : int = 2
187279
188280 def upload (self , file_path : str , filename : str , target : str , source : Optional [str ] = None ,
189- adapt_to : Optional [List [str ]] = None , no_trace : bool = False ) -> Document :
281+ adapt_to : Optional [List [str ]] = None , glossaries : Optional [ List [ str ]] = None , no_trace : bool = False ) -> Document :
190282 with open (file_path , 'rb' ) as file_payload :
191283 response_data = self ._client .get ('/documents/upload-url' , {'filename' : filename })
192284
@@ -205,6 +297,9 @@ def upload(self, file_path: str, filename: str, target: str, source: Optional[st
205297 if adapt_to is not None :
206298 body ['adapt_to' ] = adapt_to
207299
300+ if glossaries is not None :
301+ body ['glossaries' ] = glossaries
302+
208303 headers = None
209304 if no_trace is True :
210305 headers = {'X-No-Trace' : 'true' }
@@ -222,11 +317,11 @@ def download(self, id: str, output_format: Optional[str] = None) -> bytes:
222317 return self ._s3client .download (url = url )
223318
224319 def translate (self , file_path : str , filename : str , target : str , source : Optional [str ] = None ,
225- adapt_to : Optional [List [str ]] = None , output_format : Optional [str ] = None ,
320+ adapt_to : Optional [List [str ]] = None , glossaries : Optional [ List [ str ]] = None , output_format : Optional [str ] = None ,
226321 no_trace : bool = False ) -> bytes :
227322
228323 document = self .upload (file_path = file_path , filename = filename , target = target , source = source , adapt_to = adapt_to ,
229- no_trace = no_trace )
324+ glossaries = glossaries , no_trace = no_trace )
230325
231326 max_wait_time = 60 * 15 # 15 minutes
232327 start = time .time ()
@@ -265,13 +360,14 @@ def __init__(self, credentials: Credentials = None, *,
265360 self ._client : LaraClient = LaraClient (credentials .access_key_id , credentials .access_key_secret , server_url )
266361 self .memories : Memories = Memories (self ._client )
267362 self .documents : Documents = Documents (self ._client )
363+ self .glossaries : Glossaries = Glossaries (self ._client )
268364
269365 def languages (self ) -> List [str ]:
270366 return self ._client .get ('/languages' )
271367
272368 def translate (self , text : Union [str , Iterable [str ], Iterable [TextBlock ]], * ,
273369 source : str = None , source_hint : str = None , target : str , adapt_to : List [str ] = None ,
274- instructions : List [str ] = None , content_type : str = None ,
370+ glossaries : List [ str ] = None , instructions : List [str ] = None , content_type : str = None ,
275371 multiline : bool = True , timeout_ms : int = None , priority : TranslatePriority = None ,
276372 use_cache : Union [bool , UseCache ] = None , cache_ttl_s : int = None ,
277373 no_trace : bool = False , verbose : bool = False ) -> TextResult :
@@ -297,7 +393,7 @@ def translate(self, text: Union[str, Iterable[str], Iterable[TextBlock]], *,
297393 'multiline' : multiline , 'adapt_to' : adapt_to , 'instructions' : instructions , 'timeout' : timeout_ms , 'q' : q ,
298394 'priority' : priority .value if priority is not None else None ,
299395 'use_cache' : use_cache .value if use_cache is not None else None , 'cache_ttl' : cache_ttl_s ,
300- 'verbose' : verbose
396+ 'glossaries' : glossaries , ' verbose' : verbose
301397 }
302398
303399 headers = None
0 commit comments