Skip to content

Commit fa3e622

Browse files
authored
Add DetectPrediction class and predictions to DetectResult
1 parent fbf1c97 commit fa3e622

2 files changed

Lines changed: 8 additions & 2 deletions

File tree

src/lara_sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from ._client import LaraObject
22
from ._credentials import Credentials
33
from ._errors import LaraApiError, LaraError
4-
from ._translator import Memory, MemoryImport, TextBlock, TextResult, DetectResult, Memories, Translator, TranslatePriority, UseCache, Documents, Document, DocumentStatus, DocxExtractionParams, DocumentExtractionParams
4+
from ._translator import Memory, MemoryImport, TextBlock, TextResult, DetectPrediction, DetectResult, Memories, Translator, TranslatePriority, UseCache, Documents, Document, DocumentStatus, DocxExtractionParams, DocumentExtractionParams
55

66
# This constant is auto-generated by the build script.
77
# Manual modifications will be overwritten and may cause unexpected behavior.

src/lara_sdk/_translator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,16 @@ def __init__(self, **kwargs):
141141
self.translation = [TextBlock(**e) for e in translation]
142142

143143

144+
class DetectPrediction(LaraObject):
145+
def __init__(self, **kwargs):
146+
self.language: str = kwargs.get('language')
147+
self.confidence: float = kwargs.get('confidence')
148+
144149
class DetectResult(LaraObject):
145150
def __init__(self, **kwargs):
146151
self.language: str = kwargs.get('language')
147152
self.content_type: str = kwargs.get('content_type')
153+
self.predictions: List[DetectPrediction] = [DetectPrediction(**p) for p in kwargs.get('predictions', [])]
148154

149155

150156
# Translator SDK -------------------------------------------------------------------------------------------------------
@@ -480,5 +486,5 @@ def detect(self, text: Union[str, List[str]], *, hint: Optional[str] = None,
480486
'hint': hint,
481487
'passlist': passlist
482488
}
483-
489+
484490
return DetectResult(**self._client.post('/detect', body))

0 commit comments

Comments
 (0)