-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathtest_inference_parameters.py
More file actions
85 lines (68 loc) · 2.48 KB
/
test_inference_parameters.py
File metadata and controls
85 lines (68 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import json
import pytest
from mindee import InferenceParameters
from mindee.input.inference_parameters import (
DataSchema,
DataSchemaReplace,
DataSchemaField,
)
from tests.utils import V2_PRODUCT_DATA_DIR
# Reference data schema loaded from the shared V2 extraction fixtures. Every
# test below builds a data schema in a different way and compares it against
# this. The file is read explicitly as UTF-8 (the JSON standard encoding) so
# the result does not depend on the platform's locale encoding.
expected_data_schema_dict = json.loads(
    (V2_PRODUCT_DATA_DIR / "extraction" / "data_schema_replace_param.json").read_text(
        encoding="utf-8"
    )
)
# Compact (no indentation), key-sorted serialization of the fixture. The tests
# assert that ``str(params.data_schema)`` produces exactly this string.
expected_data_schema_str = json.dumps(
    expected_data_schema_dict, indent=None, sort_keys=True
)
def test_data_schema_replace_none():
    """Parameters built without a data schema expose ``data_schema`` as ``None``."""
    schema = InferenceParameters(model_id="test-id").data_schema
    assert schema is None
def test_data_schema_replace_str():
    """A schema given as a JSON string stringifies back to that same string."""
    inference_params = InferenceParameters(
        model_id="test-id",
        data_schema=expected_data_schema_str,
    )
    serialized = str(inference_params.data_schema)
    assert serialized == expected_data_schema_str
def test_data_schema_replace_dict():
    """A schema given as a plain dict stringifies to the fixture's JSON string."""
    inference_params = InferenceParameters(
        model_id="test-id",
        data_schema=expected_data_schema_dict,
    )
    serialized = str(inference_params.data_schema)
    assert serialized == expected_data_schema_str
def test_data_schema_replace_obj_top():
    """A ``DataSchema`` wrapping the fixture's ``replace`` section matches the fixture."""
    replace_section = expected_data_schema_dict["replace"]
    schema = DataSchema(replace=replace_section)
    inference_params = InferenceParameters(model_id="test-id", data_schema=schema)
    assert str(inference_params.data_schema) == expected_data_schema_str
def test_data_schema_replace_obj_fields():
    """Nesting a ``DataSchemaReplace`` built from the fixture's ``fields`` list works."""
    fixture_fields = expected_data_schema_dict["replace"]["fields"]
    schema = DataSchema(replace=DataSchemaReplace(fields=fixture_fields))
    inference_params = InferenceParameters(model_id="test-id", data_schema=schema)
    assert str(inference_params.data_schema) == expected_data_schema_str
def test_data_schema_replace_empty_fields():
    """An empty ``fields`` list in the replacement section raises ``ValueError``."""
    empty_schema = {"replace": {"fields": []}}
    with pytest.raises(
        ValueError, match="Data schema replacement fields cannot be empty"
    ):
        InferenceParameters(model_id="test-id", data_schema=empty_schema)
def test_data_schema_replace_obj_full():
    """A schema assembled entirely from typed objects matches the JSON fixture.

    The field values below must stay in sync with
    ``extraction/data_schema_replace_param.json``, because the result is
    compared against that file's serialized form.
    """
    replace_field = DataSchemaField(
        name="test_replace",
        title="Test Replace",
        type="string",
        is_array=False,
        description="A static value for testing.",
        guidelines="IMPORTANT: always return this exact string: 'a test value'",
    )
    schema = DataSchema(replace=DataSchemaReplace(fields=[replace_field]))
    inference_params = InferenceParameters(model_id="test-id", data_schema=schema)
    assert str(inference_params.data_schema) == expected_data_schema_str