|
1 | | -from unittest.mock import Mock, patch |
| 1 | +""" |
| 2 | +Tests for PrimingGroup and ML pipeline methods on ExperimentalApi. |
| 3 | +
|
| 4 | +The create/get/delete priming group tests are marked @pytest.mark.expensive because they |
| 5 | +require training a detector (submit images, wait ~45 s) before a pipeline has a |
| 6 | +cached_vizlogic_key that can seed a PrimingGroup. Run them explicitly with: |
| 7 | +
|
| 8 | + pytest -m expensive test/unit/test_priming_groups.py |
| 9 | +""" |
| 10 | + |
| 11 | +import time |
2 | 12 |
|
3 | 13 | import pytest |
4 | 14 | from groundlight import ExperimentalApi |
5 | 15 | from groundlight.internalapi import NotFoundError |
6 | 16 | from model import MLPipeline, PrimingGroup |
7 | 17 |
|
8 | | -# --------------------------------------------------------------------------- |
9 | | -# Helpers |
10 | | -# --------------------------------------------------------------------------- |
11 | | - |
12 | | -MOCK_PIPELINE_DATA = { |
13 | | - "id": "pipe_abc123", |
14 | | - "pipeline_config": "never-review", |
15 | | - "cached_vizlogic_key": "mlb_testkey1234", |
16 | | - "is_active_pipeline": True, |
17 | | - "is_edge_pipeline": False, |
18 | | - "is_unclear_pipeline": False, |
19 | | - "is_oodd_pipeline": False, |
20 | | - "is_enabled": True, |
21 | | - "created_at": "2026-01-01T00:00:00Z", |
22 | | - "trained_at": None, |
23 | | -} |
24 | | - |
25 | | -MOCK_PRIMING_GROUP_DATA = { |
26 | | - "id": "pgp_abc123", |
27 | | - "name": "door-detector-primer", |
28 | | - "is_global": False, |
29 | | - "canonical_query": "Is the door open?", |
30 | | - "active_pipeline_config": None, |
31 | | - "active_pipeline_base_mlbinary_key": "mlb_testkey1234", |
32 | | - "priming_group_specific_shadow_pipeline_configs": None, |
33 | | - "disable_shadow_pipelines": False, |
34 | | - "created_at": "2026-01-01T00:00:00Z", |
35 | | -} |
36 | | - |
37 | | - |
38 | | -def _mock_response(status_code=200, json_data=None): |
39 | | - resp = Mock() |
40 | | - resp.status_code = status_code |
41 | | - resp.json.return_value = json_data or {} |
42 | | - resp.raise_for_status.return_value = None |
43 | | - return resp |
44 | | - |
45 | | - |
46 | | -def _mock_error_response(status_code): |
47 | | - resp = Mock() |
48 | | - resp.status_code = status_code |
49 | | - resp.raise_for_status.side_effect = Exception(f"HTTP {status_code}") |
50 | | - return resp |
51 | | - |
52 | 18 |
|
53 | 19 | # --------------------------------------------------------------------------- |
54 | 20 | # list_detector_pipelines |
55 | 21 | # --------------------------------------------------------------------------- |
56 | 22 |
|
57 | 23 |
|
58 | | -def test_list_detector_pipelines_returns_pipelines(): |
59 | | - gl = ExperimentalApi() |
60 | | - with patch("requests.get") as mock_get: |
61 | | - mock_get.return_value = _mock_response(json_data={"results": [MOCK_PIPELINE_DATA], "count": 1}) |
62 | | - pipelines = gl.list_detector_pipelines("det_abc123") |
| 24 | +def test_list_detector_pipelines_returns_list(gl_experimental: ExperimentalApi, detector): |
| 25 | + """A freshly created detector has at least one pipeline.""" |
| 26 | + pipelines = gl_experimental.list_detector_pipelines(detector) |
| 27 | + assert isinstance(pipelines, list) |
| 28 | + assert len(pipelines) >= 1 |
| 29 | + assert all(isinstance(p, MLPipeline) for p in pipelines) |
63 | 30 |
|
64 | | - assert len(pipelines) == 1 |
65 | | - assert isinstance(pipelines[0], MLPipeline) |
66 | | - assert pipelines[0].id == "pipe_abc123" |
67 | | - assert pipelines[0].is_active_pipeline is True |
68 | 31 |
|
| 32 | +def test_list_detector_pipelines_accepts_detector_id_string(gl_experimental: ExperimentalApi, detector): |
| 33 | + """list_detector_pipelines should accept a raw ID string as well as a Detector object.""" |
| 34 | + by_obj = gl_experimental.list_detector_pipelines(detector) |
| 35 | + by_id = gl_experimental.list_detector_pipelines(detector.id) |
| 36 | + assert [p.id for p in by_obj] == [p.id for p in by_id] |
69 | 37 |
|
70 | | -def test_list_detector_pipelines_accepts_detector_object(): |
71 | | - gl = ExperimentalApi() |
72 | | - detector = Mock() |
73 | | - detector.id = "det_abc123" |
74 | | - with patch("requests.get") as mock_get: |
75 | | - mock_get.return_value = _mock_response(json_data={"results": [MOCK_PIPELINE_DATA], "count": 1}) |
76 | | - pipelines = gl.list_detector_pipelines(detector) |
77 | 38 |
|
78 | | - assert len(pipelines) == 1 |
79 | | - assert mock_get.call_args[0][0].endswith("/v1/detectors/det_abc123/pipelines") |
80 | | - |
81 | | - |
82 | | -def test_list_detector_pipelines_empty(): |
83 | | - gl = ExperimentalApi() |
84 | | - with patch("requests.get") as mock_get: |
85 | | - mock_get.return_value = _mock_response(json_data={"results": [], "count": 0}) |
86 | | - pipelines = gl.list_detector_pipelines("det_abc123") |
87 | | - |
88 | | - assert pipelines == [] |
89 | | - |
90 | | - |
91 | | -def test_list_detector_pipelines_404_raises_not_found(): |
92 | | - gl = ExperimentalApi() |
93 | | - with patch("requests.get") as mock_get: |
94 | | - mock_get.return_value = _mock_error_response(404) |
95 | | - with pytest.raises(NotFoundError, match="det_notexist"): |
96 | | - gl.list_detector_pipelines("det_notexist") |
| 39 | +def test_list_detector_pipelines_unknown_detector_raises(gl_experimental: ExperimentalApi): |
| 40 | + with pytest.raises(NotFoundError): |
| 41 | + gl_experimental.list_detector_pipelines("det_doesnotexist000000000000") |
97 | 42 |
|
98 | 43 |
|
99 | 44 | # --------------------------------------------------------------------------- |
100 | 45 | # list_priming_groups |
101 | 46 | # --------------------------------------------------------------------------- |
102 | 47 |
|
103 | 48 |
|
104 | | -def test_list_priming_groups_returns_groups(): |
105 | | - gl = ExperimentalApi() |
106 | | - with patch("requests.get") as mock_get: |
107 | | - mock_get.return_value = _mock_response(json_data={"results": [MOCK_PRIMING_GROUP_DATA], "count": 1}) |
108 | | - groups = gl.list_priming_groups() |
109 | | - |
110 | | - assert len(groups) == 1 |
111 | | - assert isinstance(groups[0], PrimingGroup) |
112 | | - assert groups[0].id == "pgp_abc123" |
113 | | - assert groups[0].name == "door-detector-primer" |
114 | | - |
115 | | - |
116 | | -def test_list_priming_groups_empty(): |
117 | | - gl = ExperimentalApi() |
118 | | - with patch("requests.get") as mock_get: |
119 | | - mock_get.return_value = _mock_response(json_data={"results": [], "count": 0}) |
120 | | - groups = gl.list_priming_groups() |
121 | | - |
122 | | - assert groups == [] |
123 | | - |
124 | | - |
125 | | -def test_list_priming_groups_hits_correct_url(): |
126 | | - gl = ExperimentalApi() |
127 | | - with patch("requests.get") as mock_get: |
128 | | - mock_get.return_value = _mock_response(json_data={"results": [], "count": 0}) |
129 | | - gl.list_priming_groups() |
130 | | - |
131 | | - called_url = mock_get.call_args[0][0] |
132 | | - assert called_url.endswith("/v1/priming-groups") |
| 49 | +def test_list_priming_groups_returns_list(gl_experimental: ExperimentalApi): |
| 50 | + groups = gl_experimental.list_priming_groups() |
| 51 | + assert isinstance(groups, list) |
| 52 | + assert all(isinstance(g, PrimingGroup) for g in groups) |
133 | 53 |
|
134 | 54 |
|
135 | 55 | # --------------------------------------------------------------------------- |
136 | | -# create_priming_group |
| 56 | +# create / get / delete (expensive — require a trained pipeline) |
137 | 57 | # --------------------------------------------------------------------------- |
138 | 58 |
|
139 | 59 |
|
140 | | -def test_create_priming_group_returns_group(): |
141 | | - gl = ExperimentalApi() |
142 | | - with patch("requests.post") as mock_post: |
143 | | - mock_post.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA) |
144 | | - pg = gl.create_priming_group( |
145 | | - name="door-detector-primer", |
146 | | - source_ml_pipeline_id="pipe_abc123", |
147 | | - canonical_query="Is the door open?", |
148 | | - ) |
149 | | - |
150 | | - assert isinstance(pg, PrimingGroup) |
151 | | - assert pg.id == "pgp_abc123" |
152 | | - assert pg.name == "door-detector-primer" |
153 | | - assert pg.canonical_query == "Is the door open?" |
| 60 | +def _wait_for_trained_pipeline(gl_experimental: ExperimentalApi, detector, timeout: int = 45) -> MLPipeline: |
| 61 | + """ |
| 62 | + Submit the cat and dog test images, then poll until the active pipeline has a |
| 63 | + cached_vizlogic_key (i.e. has been trained). Raises TimeoutError if training |
| 64 | + doesn't complete within `timeout` seconds. |
| 65 | + """ |
| 66 | + gl_experimental.submit_image_query(detector, "test/assets/dog.jpeg", human_review="NEVER") |
| 67 | + gl_experimental.submit_image_query(detector, "test/assets/cat.jpeg", human_review="NEVER") |
154 | 68 |
|
| 69 | + deadline = time.monotonic() + timeout |
| 70 | + while time.monotonic() < deadline: |
| 71 | + pipelines = gl_experimental.list_detector_pipelines(detector) |
| 72 | + for p in pipelines: |
| 73 | + if p.is_active_pipeline and p.cached_vizlogic_key: |
| 74 | + return p |
| 75 | + time.sleep(5) |
155 | 76 |
|
156 | | -def test_create_priming_group_sends_correct_payload(): |
157 | | - gl = ExperimentalApi() |
158 | | - with patch("requests.post") as mock_post: |
159 | | - mock_post.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA) |
160 | | - gl.create_priming_group( |
161 | | - name="door-detector-primer", |
162 | | - source_ml_pipeline_id="pipe_abc123", |
163 | | - canonical_query="Is the door open?", |
164 | | - disable_shadow_pipelines=True, |
165 | | - ) |
| 77 | + raise TimeoutError(f"Detector {detector.id} did not produce a trained pipeline within {timeout}s") |
166 | 78 |
|
167 | | - payload = mock_post.call_args[1]["json"] |
168 | | - assert payload["name"] == "door-detector-primer" |
169 | | - assert payload["source_ml_pipeline_id"] == "pipe_abc123" |
170 | | - assert payload["canonical_query"] == "Is the door open?" |
171 | | - assert payload["disable_shadow_pipelines"] is True |
172 | 79 |
|
| 80 | +@pytest.mark.expensive |
| 81 | +def test_create_priming_group(gl_experimental: ExperimentalApi, detector): |
| 82 | + trained = _wait_for_trained_pipeline(gl_experimental, detector) |
173 | 83 |
|
174 | | -def test_create_priming_group_omits_canonical_query_when_none(): |
175 | | - gl = ExperimentalApi() |
176 | | - with patch("requests.post") as mock_post: |
177 | | - mock_post.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA) |
178 | | - gl.create_priming_group(name="primer", source_ml_pipeline_id="pipe_abc123") |
| 84 | + pg = gl_experimental.create_priming_group( |
| 85 | + name=f"test-primer-{detector.id}", |
| 86 | + source_ml_pipeline_id=trained.id, |
| 87 | + canonical_query="Is there a dog?", |
| 88 | + ) |
179 | 89 |
|
180 | | - payload = mock_post.call_args[1]["json"] |
181 | | - assert "canonical_query" not in payload |
| 90 | + assert isinstance(pg, PrimingGroup) |
| 91 | + assert pg.id.startswith("pgp_") |
| 92 | + assert pg.name == f"test-primer-{detector.id}" |
| 93 | + assert pg.canonical_query == "Is there a dog?" |
| 94 | + assert pg.is_global is False |
| 95 | + assert pg.active_pipeline_base_mlbinary_key is not None |
182 | 96 |
|
| 97 | + # cleanup |
| 98 | + gl_experimental.delete_priming_group(pg.id) |
183 | 99 |
|
184 | | -def test_create_priming_group_disable_shadow_pipelines_default_false(): |
185 | | - gl = ExperimentalApi() |
186 | | - with patch("requests.post") as mock_post: |
187 | | - mock_post.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA) |
188 | | - gl.create_priming_group(name="primer", source_ml_pipeline_id="pipe_abc123") |
189 | 100 |
|
190 | | - payload = mock_post.call_args[1]["json"] |
191 | | - assert payload["disable_shadow_pipelines"] is False |
| 101 | +@pytest.mark.expensive |
| 102 | +def test_create_priming_group_disable_shadow_pipelines(gl_experimental: ExperimentalApi, detector): |
| 103 | + trained = _wait_for_trained_pipeline(gl_experimental, detector) |
192 | 104 |
|
| 105 | + pg = gl_experimental.create_priming_group( |
| 106 | + name=f"test-primer-noshadow-{detector.id}", |
| 107 | + source_ml_pipeline_id=trained.id, |
| 108 | + disable_shadow_pipelines=True, |
| 109 | + ) |
193 | 110 |
|
194 | | -# --------------------------------------------------------------------------- |
195 | | -# get_priming_group |
196 | | -# --------------------------------------------------------------------------- |
| 111 | + assert pg.disable_shadow_pipelines is True |
197 | 112 |
|
| 113 | + gl_experimental.delete_priming_group(pg.id) |
198 | 114 |
|
199 | | -def test_get_priming_group_returns_group(): |
200 | | - gl = ExperimentalApi() |
201 | | - with patch("requests.get") as mock_get: |
202 | | - mock_get.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA) |
203 | | - pg = gl.get_priming_group("pgp_abc123") |
204 | 115 |
|
205 | | - assert isinstance(pg, PrimingGroup) |
206 | | - assert pg.id == "pgp_abc123" |
207 | | - assert pg.name == "door-detector-primer" |
| 116 | +@pytest.mark.expensive |
| 117 | +def test_get_priming_group(gl_experimental: ExperimentalApi, detector): |
| 118 | + trained = _wait_for_trained_pipeline(gl_experimental, detector) |
208 | 119 |
|
| 120 | + pg = gl_experimental.create_priming_group( |
| 121 | + name=f"test-primer-get-{detector.id}", |
| 122 | + source_ml_pipeline_id=trained.id, |
| 123 | + ) |
209 | 124 |
|
210 | | -def test_get_priming_group_404_raises_not_found(): |
211 | | - gl = ExperimentalApi() |
212 | | - with patch("requests.get") as mock_get: |
213 | | - mock_get.return_value = _mock_error_response(404) |
214 | | - with pytest.raises(NotFoundError, match="pgp_notexist"): |
215 | | - gl.get_priming_group("pgp_notexist") |
| 125 | + fetched = gl_experimental.get_priming_group(pg.id) |
| 126 | + assert fetched.id == pg.id |
| 127 | + assert fetched.name == pg.name |
216 | 128 |
|
| 129 | + gl_experimental.delete_priming_group(pg.id) |
217 | 130 |
|
218 | | -def test_get_priming_group_hits_correct_url(): |
219 | | - gl = ExperimentalApi() |
220 | | - with patch("requests.get") as mock_get: |
221 | | - mock_get.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA) |
222 | | - gl.get_priming_group("pgp_abc123") |
223 | 131 |
|
224 | | - called_url = mock_get.call_args[0][0] |
225 | | - assert called_url.endswith("/v1/priming-groups/pgp_abc123") |
| 132 | +@pytest.mark.expensive |
| 133 | +def test_get_priming_group_unknown_raises(gl_experimental: ExperimentalApi): |
| 134 | + with pytest.raises(NotFoundError): |
| 135 | + gl_experimental.get_priming_group("pgp_doesnotexist000000000000") |
226 | 136 |
|
227 | 137 |
|
228 | | -# --------------------------------------------------------------------------- |
229 | | -# delete_priming_group |
230 | | -# --------------------------------------------------------------------------- |
| 138 | +@pytest.mark.expensive |
| 139 | +def test_delete_priming_group(gl_experimental: ExperimentalApi, detector): |
| 140 | + trained = _wait_for_trained_pipeline(gl_experimental, detector) |
231 | 141 |
|
| 142 | + pg = gl_experimental.create_priming_group( |
| 143 | + name=f"test-primer-del-{detector.id}", |
| 144 | + source_ml_pipeline_id=trained.id, |
| 145 | + ) |
232 | 146 |
|
233 | | -def test_delete_priming_group_succeeds(): |
234 | | - gl = ExperimentalApi() |
235 | | - with patch("requests.delete") as mock_delete: |
236 | | - mock_delete.return_value = _mock_response(status_code=204) |
237 | | - gl.delete_priming_group("pgp_abc123") # should not raise |
| 147 | + gl_experimental.delete_priming_group(pg.id) |
238 | 148 |
|
239 | | - mock_delete.assert_called_once() |
| 149 | + with pytest.raises(NotFoundError): |
| 150 | + gl_experimental.get_priming_group(pg.id) |
240 | 151 |
|
241 | 152 |
|
242 | | -def test_delete_priming_group_404_raises_not_found(): |
243 | | - gl = ExperimentalApi() |
244 | | - with patch("requests.delete") as mock_delete: |
245 | | - mock_delete.return_value = _mock_error_response(404) |
246 | | - with pytest.raises(NotFoundError, match="pgp_notexist"): |
247 | | - gl.delete_priming_group("pgp_notexist") |
| 153 | +@pytest.mark.expensive |
| 154 | +def test_created_priming_group_appears_in_list(gl_experimental: ExperimentalApi, detector): |
| 155 | + trained = _wait_for_trained_pipeline(gl_experimental, detector) |
248 | 156 |
|
| 157 | + pg = gl_experimental.create_priming_group( |
| 158 | + name=f"test-primer-list-{detector.id}", |
| 159 | + source_ml_pipeline_id=trained.id, |
| 160 | + ) |
249 | 161 |
|
250 | | -def test_delete_priming_group_hits_correct_url(): |
251 | | - gl = ExperimentalApi() |
252 | | - with patch("requests.delete") as mock_delete: |
253 | | - mock_delete.return_value = _mock_response(status_code=204) |
254 | | - gl.delete_priming_group("pgp_abc123") |
| 162 | + groups = gl_experimental.list_priming_groups() |
| 163 | + assert any(g.id == pg.id for g in groups) |
255 | 164 |
|
256 | | - called_url = mock_delete.call_args[0][0] |
257 | | - assert called_url.endswith("/v1/priming-groups/pgp_abc123") |
| 165 | + gl_experimental.delete_priming_group(pg.id) |
0 commit comments