Skip to content

Commit cbc9ef8

Browse files
feat(tagging): add retrocompatibility with text parameter
1 parent 76cf227 commit cbc9ef8

3 files changed

Lines changed: 74 additions & 6 deletions

File tree

hrflow/hrflow/text/tagging.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ def __init__(self, api):
1313
def post(
1414
self,
1515
algorithm_key: str,
16-
texts: t.List[str],
16+
text: t.Optional[str] = None,
17+
texts: t.Optional[t.List[str]] = None,
1718
context: t.Optional[str] = None,
1819
labels: t.Optional[t.List[str]] = None,
1920
top_n: t.Optional[int] = 1,
@@ -62,16 +63,22 @@ def post(
6263
Returns:
6364
`/text/tagging` response
6465
"""
65-
6666
payload = dict(
6767
algorithm_key=algorithm_key,
68-
texts=texts,
6968
context=context,
7069
labels=labels,
7170
output_lang=output_lang,
7271
top_n=top_n,
7372
)
74-
73+
74+
if texts is None and text is not None:
75+
payload["text"] = text
76+
elif text is None and texts is not None:
77+
payload["texts"] = texts
78+
elif text is None and texts is None:
79+
raise ValueError("Either text or texts must be provided.")
80+
else:
81+
raise ValueError("Only one of text or texts must be provided.")
82+
7583
response = self.client.post("text/tagging", json=payload)
76-
7784
return validate_response(response)

tests/test_text.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TextLinkingResponse,
1313
TextOCRResponse,
1414
TextParsingResponse,
15+
TextTaggingDataItem,
1516
TextTaggingReponse,
1617
)
1718
from .utils.tools import _file_get, _var_from_env_get
@@ -157,6 +158,63 @@ def test_linking_negative_amount(hrflow_client):
157158
assert model.code == requests.codes.bad_request
158159

159160

161+
@pytest.mark.text
162+
@pytest.mark.tagging
163+
def test_tagger_rome_family_with_text_param(hrflow_client):
164+
model = TextTaggingReponse.model_validate(
165+
hrflow_client.text.tagging.post(
166+
algorithm_key=TAGGING_ALGORITHM.TAGGER_ROME_FAMILY,
167+
text=TAGGING_TEXTS[0],
168+
top_n=2,
169+
)
170+
)
171+
assert model.code == requests.codes.ok
172+
assert isinstance(model.data, TextTaggingDataItem)
173+
174+
@pytest.mark.text
175+
@pytest.mark.tagging
176+
def test_tagger_rome_family_with_texts_param(hrflow_client):
177+
model = TextTaggingReponse.model_validate(
178+
hrflow_client.text.tagging.post(
179+
algorithm_key=TAGGING_ALGORITHM.TAGGER_ROME_FAMILY,
180+
texts=TAGGING_TEXTS,
181+
top_n=2,
182+
)
183+
)
184+
assert model.code == requests.codes.ok
185+
assert isinstance(model.data, list)
186+
assert len(model.data) == len(TAGGING_TEXTS)
187+
188+
@pytest.mark.text
189+
@pytest.mark.tagging
190+
def test_tagger_rome_family_with_text_and_texts_param(hrflow_client):
191+
try:
192+
TextTaggingReponse.model_validate(
193+
hrflow_client.text.tagging.post(
194+
algorithm_key=TAGGING_ALGORITHM.TAGGER_ROME_FAMILY,
195+
text=TAGGING_TEXTS[0],
196+
texts=TAGGING_TEXTS,
197+
top_n=2,
198+
)
199+
)
200+
pytest.fail("Should have raised a ValueError")
201+
except ValueError as e:
202+
pass
203+
204+
@pytest.mark.text
205+
@pytest.mark.tagging
206+
def test_tagger_rome_family_without_text_or_texts_param(hrflow_client):
207+
try:
208+
TextTaggingReponse.model_validate(
209+
hrflow_client.text.tagging.post(
210+
algorithm_key=TAGGING_ALGORITHM.TAGGER_ROME_FAMILY,
211+
top_n=2,
212+
)
213+
)
214+
pytest.fail("Should have raised a ValueError")
215+
except ValueError as e:
216+
pass
217+
160218
def _tagging_test(
161219
hrflow_client: Hrflow,
162220
algorithm_key: TAGGING_ALGORITHM,

tests/utils/schemas.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ class TextTaggingDataItem(BaseModel):
100100
@model_validator(mode="before")
101101
@classmethod
102102
def _check(cls, values: t.Dict[str, t.List[t.Any]]) -> t.Dict[str, t.List[t.Any]]:
103+
if isinstance(values, list):
104+
return [cls._check(item) for item in values
105+
]
103106
li = len(values.get("ids"))
104107
lp = len(values.get("predictions"))
105108
lt = len(values.get("tags"))
@@ -111,7 +114,7 @@ def _check(cls, values: t.Dict[str, t.List[t.Any]]) -> t.Dict[str, t.List[t.Any]
111114

112115

113116
class TextTaggingReponse(HrFlowAPIResponse):
114-
data: t.Optional[t.List[TextTaggingDataItem]] = None
117+
data: t.Optional[t.Union[t.List[TextTaggingDataItem], TextTaggingDataItem]] = None
115118

116119

117120
class TextParsingDataItemEntity(BaseModel):

0 commit comments

Comments
 (0)