Skip to content

Commit 0afd38a

Browse files
committed
test(priming-group): add real integration tests; add expensive marker
1 parent 184d6a0 commit 0afd38a

3 files changed

Lines changed: 110 additions & 201 deletions

File tree

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ PYTEST=poetry run pytest -v
3838
# `make test TEST_ARGS="-k some_filter"`
3939
TEST_ARGS=
4040

41-
CLOUD_FILTERS = -m "not run_only_for_edge_endpoint"
42-
EDGE_FILTERS = -m "not skip_for_edge_endpoint"
41+
CLOUD_FILTERS = -m "not run_only_for_edge_endpoint and not expensive"
42+
EDGE_FILTERS = -m "not skip_for_edge_endpoint and not expensive"
4343

4444
# Record information about the slowest 25 tests (but don't show anything slower than 0.1 seconds)
4545
PROFILING_ARGS = \

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ trailing_comma_inline_array = true
8484
markers = [
8585
"skip_for_edge_endpoint",
8686
"run_only_for_edge_endpoint",
87+
"expensive: marks tests as expensive (long-running; excluded from default test runs)",
8788
]
8889

8990
[build-system]

test/unit/test_priming_groups.py

Lines changed: 107 additions & 199 deletions
Original file line numberDiff line numberDiff line change
@@ -1,257 +1,165 @@
1-
from unittest.mock import Mock, patch
1+
"""
2+
Tests for PrimingGroup and ML pipeline methods on ExperimentalApi.
3+
4+
The create/get/delete priming group tests are marked @pytest.mark.expensive because they
5+
require training a detector (submit images, wait ~45 s) before a pipeline has a
6+
cached_vizlogic_key that can seed a PrimingGroup. Run them explicitly with:
7+
8+
pytest -m expensive test/unit/test_priming_groups.py
9+
"""
10+
11+
import time
212

313
import pytest
414
from groundlight import ExperimentalApi
515
from groundlight.internalapi import NotFoundError
616
from model import MLPipeline, PrimingGroup
717

8-
# ---------------------------------------------------------------------------
9-
# Helpers
10-
# ---------------------------------------------------------------------------
11-
12-
MOCK_PIPELINE_DATA = {
13-
"id": "pipe_abc123",
14-
"pipeline_config": "never-review",
15-
"cached_vizlogic_key": "mlb_testkey1234",
16-
"is_active_pipeline": True,
17-
"is_edge_pipeline": False,
18-
"is_unclear_pipeline": False,
19-
"is_oodd_pipeline": False,
20-
"is_enabled": True,
21-
"created_at": "2026-01-01T00:00:00Z",
22-
"trained_at": None,
23-
}
24-
25-
MOCK_PRIMING_GROUP_DATA = {
26-
"id": "pgp_abc123",
27-
"name": "door-detector-primer",
28-
"is_global": False,
29-
"canonical_query": "Is the door open?",
30-
"active_pipeline_config": None,
31-
"active_pipeline_base_mlbinary_key": "mlb_testkey1234",
32-
"priming_group_specific_shadow_pipeline_configs": None,
33-
"disable_shadow_pipelines": False,
34-
"created_at": "2026-01-01T00:00:00Z",
35-
}
36-
37-
38-
def _mock_response(status_code=200, json_data=None):
39-
resp = Mock()
40-
resp.status_code = status_code
41-
resp.json.return_value = json_data or {}
42-
resp.raise_for_status.return_value = None
43-
return resp
44-
45-
46-
def _mock_error_response(status_code):
47-
resp = Mock()
48-
resp.status_code = status_code
49-
resp.raise_for_status.side_effect = Exception(f"HTTP {status_code}")
50-
return resp
51-
5218

5319
# ---------------------------------------------------------------------------
5420
# list_detector_pipelines
5521
# ---------------------------------------------------------------------------
5622

5723

58-
def test_list_detector_pipelines_returns_pipelines():
59-
gl = ExperimentalApi()
60-
with patch("requests.get") as mock_get:
61-
mock_get.return_value = _mock_response(json_data={"results": [MOCK_PIPELINE_DATA], "count": 1})
62-
pipelines = gl.list_detector_pipelines("det_abc123")
24+
def test_list_detector_pipelines_returns_list(gl_experimental: ExperimentalApi, detector):
25+
"""A freshly created detector has at least one pipeline."""
26+
pipelines = gl_experimental.list_detector_pipelines(detector)
27+
assert isinstance(pipelines, list)
28+
assert len(pipelines) >= 1
29+
assert all(isinstance(p, MLPipeline) for p in pipelines)
6330

64-
assert len(pipelines) == 1
65-
assert isinstance(pipelines[0], MLPipeline)
66-
assert pipelines[0].id == "pipe_abc123"
67-
assert pipelines[0].is_active_pipeline is True
6831

32+
def test_list_detector_pipelines_accepts_detector_id_string(gl_experimental: ExperimentalApi, detector):
33+
"""list_detector_pipelines should accept a raw ID string as well as a Detector object."""
34+
by_obj = gl_experimental.list_detector_pipelines(detector)
35+
by_id = gl_experimental.list_detector_pipelines(detector.id)
36+
assert [p.id for p in by_obj] == [p.id for p in by_id]
6937

70-
def test_list_detector_pipelines_accepts_detector_object():
71-
gl = ExperimentalApi()
72-
detector = Mock()
73-
detector.id = "det_abc123"
74-
with patch("requests.get") as mock_get:
75-
mock_get.return_value = _mock_response(json_data={"results": [MOCK_PIPELINE_DATA], "count": 1})
76-
pipelines = gl.list_detector_pipelines(detector)
7738

78-
assert len(pipelines) == 1
79-
assert mock_get.call_args[0][0].endswith("/v1/detectors/det_abc123/pipelines")
80-
81-
82-
def test_list_detector_pipelines_empty():
83-
gl = ExperimentalApi()
84-
with patch("requests.get") as mock_get:
85-
mock_get.return_value = _mock_response(json_data={"results": [], "count": 0})
86-
pipelines = gl.list_detector_pipelines("det_abc123")
87-
88-
assert pipelines == []
89-
90-
91-
def test_list_detector_pipelines_404_raises_not_found():
92-
gl = ExperimentalApi()
93-
with patch("requests.get") as mock_get:
94-
mock_get.return_value = _mock_error_response(404)
95-
with pytest.raises(NotFoundError, match="det_notexist"):
96-
gl.list_detector_pipelines("det_notexist")
39+
def test_list_detector_pipelines_unknown_detector_raises(gl_experimental: ExperimentalApi):
40+
with pytest.raises(NotFoundError):
41+
gl_experimental.list_detector_pipelines("det_doesnotexist000000000000")
9742

9843

9944
# ---------------------------------------------------------------------------
10045
# list_priming_groups
10146
# ---------------------------------------------------------------------------
10247

10348

104-
def test_list_priming_groups_returns_groups():
105-
gl = ExperimentalApi()
106-
with patch("requests.get") as mock_get:
107-
mock_get.return_value = _mock_response(json_data={"results": [MOCK_PRIMING_GROUP_DATA], "count": 1})
108-
groups = gl.list_priming_groups()
109-
110-
assert len(groups) == 1
111-
assert isinstance(groups[0], PrimingGroup)
112-
assert groups[0].id == "pgp_abc123"
113-
assert groups[0].name == "door-detector-primer"
114-
115-
116-
def test_list_priming_groups_empty():
117-
gl = ExperimentalApi()
118-
with patch("requests.get") as mock_get:
119-
mock_get.return_value = _mock_response(json_data={"results": [], "count": 0})
120-
groups = gl.list_priming_groups()
121-
122-
assert groups == []
123-
124-
125-
def test_list_priming_groups_hits_correct_url():
126-
gl = ExperimentalApi()
127-
with patch("requests.get") as mock_get:
128-
mock_get.return_value = _mock_response(json_data={"results": [], "count": 0})
129-
gl.list_priming_groups()
130-
131-
called_url = mock_get.call_args[0][0]
132-
assert called_url.endswith("/v1/priming-groups")
49+
def test_list_priming_groups_returns_list(gl_experimental: ExperimentalApi):
50+
groups = gl_experimental.list_priming_groups()
51+
assert isinstance(groups, list)
52+
assert all(isinstance(g, PrimingGroup) for g in groups)
13353

13454

13555
# ---------------------------------------------------------------------------
136-
# create_priming_group
56+
# create / get / delete (expensive — require a trained pipeline)
13757
# ---------------------------------------------------------------------------
13858

13959

140-
def test_create_priming_group_returns_group():
141-
gl = ExperimentalApi()
142-
with patch("requests.post") as mock_post:
143-
mock_post.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA)
144-
pg = gl.create_priming_group(
145-
name="door-detector-primer",
146-
source_ml_pipeline_id="pipe_abc123",
147-
canonical_query="Is the door open?",
148-
)
149-
150-
assert isinstance(pg, PrimingGroup)
151-
assert pg.id == "pgp_abc123"
152-
assert pg.name == "door-detector-primer"
153-
assert pg.canonical_query == "Is the door open?"
60+
def _wait_for_trained_pipeline(gl_experimental: ExperimentalApi, detector, timeout: int = 45) -> MLPipeline:
61+
"""
62+
Submit the cat and dog test images, then poll until the active pipeline has a
63+
cached_vizlogic_key (i.e. has been trained). Raises TimeoutError if training
64+
doesn't complete within `timeout` seconds.
65+
"""
66+
gl_experimental.submit_image_query(detector, "test/assets/dog.jpeg", human_review="NEVER")
67+
gl_experimental.submit_image_query(detector, "test/assets/cat.jpeg", human_review="NEVER")
15468

69+
deadline = time.monotonic() + timeout
70+
while time.monotonic() < deadline:
71+
pipelines = gl_experimental.list_detector_pipelines(detector)
72+
for p in pipelines:
73+
if p.is_active_pipeline and p.cached_vizlogic_key:
74+
return p
75+
time.sleep(5)
15576

156-
def test_create_priming_group_sends_correct_payload():
157-
gl = ExperimentalApi()
158-
with patch("requests.post") as mock_post:
159-
mock_post.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA)
160-
gl.create_priming_group(
161-
name="door-detector-primer",
162-
source_ml_pipeline_id="pipe_abc123",
163-
canonical_query="Is the door open?",
164-
disable_shadow_pipelines=True,
165-
)
77+
raise TimeoutError(f"Detector {detector.id} did not produce a trained pipeline within {timeout}s")
16678

167-
payload = mock_post.call_args[1]["json"]
168-
assert payload["name"] == "door-detector-primer"
169-
assert payload["source_ml_pipeline_id"] == "pipe_abc123"
170-
assert payload["canonical_query"] == "Is the door open?"
171-
assert payload["disable_shadow_pipelines"] is True
17279

80+
@pytest.mark.expensive
81+
def test_create_priming_group(gl_experimental: ExperimentalApi, detector):
82+
trained = _wait_for_trained_pipeline(gl_experimental, detector)
17383

174-
def test_create_priming_group_omits_canonical_query_when_none():
175-
gl = ExperimentalApi()
176-
with patch("requests.post") as mock_post:
177-
mock_post.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA)
178-
gl.create_priming_group(name="primer", source_ml_pipeline_id="pipe_abc123")
84+
pg = gl_experimental.create_priming_group(
85+
name=f"test-primer-{detector.id}",
86+
source_ml_pipeline_id=trained.id,
87+
canonical_query="Is there a dog?",
88+
)
17989

180-
payload = mock_post.call_args[1]["json"]
181-
assert "canonical_query" not in payload
90+
assert isinstance(pg, PrimingGroup)
91+
assert pg.id.startswith("pgp_")
92+
assert pg.name == f"test-primer-{detector.id}"
93+
assert pg.canonical_query == "Is there a dog?"
94+
assert pg.is_global is False
95+
assert pg.active_pipeline_base_mlbinary_key is not None
18296

97+
# cleanup
98+
gl_experimental.delete_priming_group(pg.id)
18399

184-
def test_create_priming_group_disable_shadow_pipelines_default_false():
185-
gl = ExperimentalApi()
186-
with patch("requests.post") as mock_post:
187-
mock_post.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA)
188-
gl.create_priming_group(name="primer", source_ml_pipeline_id="pipe_abc123")
189100

190-
payload = mock_post.call_args[1]["json"]
191-
assert payload["disable_shadow_pipelines"] is False
101+
@pytest.mark.expensive
102+
def test_create_priming_group_disable_shadow_pipelines(gl_experimental: ExperimentalApi, detector):
103+
trained = _wait_for_trained_pipeline(gl_experimental, detector)
192104

105+
pg = gl_experimental.create_priming_group(
106+
name=f"test-primer-noshadow-{detector.id}",
107+
source_ml_pipeline_id=trained.id,
108+
disable_shadow_pipelines=True,
109+
)
193110

194-
# ---------------------------------------------------------------------------
195-
# get_priming_group
196-
# ---------------------------------------------------------------------------
111+
assert pg.disable_shadow_pipelines is True
197112

113+
gl_experimental.delete_priming_group(pg.id)
198114

199-
def test_get_priming_group_returns_group():
200-
gl = ExperimentalApi()
201-
with patch("requests.get") as mock_get:
202-
mock_get.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA)
203-
pg = gl.get_priming_group("pgp_abc123")
204115

205-
assert isinstance(pg, PrimingGroup)
206-
assert pg.id == "pgp_abc123"
207-
assert pg.name == "door-detector-primer"
116+
@pytest.mark.expensive
117+
def test_get_priming_group(gl_experimental: ExperimentalApi, detector):
118+
trained = _wait_for_trained_pipeline(gl_experimental, detector)
208119

120+
pg = gl_experimental.create_priming_group(
121+
name=f"test-primer-get-{detector.id}",
122+
source_ml_pipeline_id=trained.id,
123+
)
209124

210-
def test_get_priming_group_404_raises_not_found():
211-
gl = ExperimentalApi()
212-
with patch("requests.get") as mock_get:
213-
mock_get.return_value = _mock_error_response(404)
214-
with pytest.raises(NotFoundError, match="pgp_notexist"):
215-
gl.get_priming_group("pgp_notexist")
125+
fetched = gl_experimental.get_priming_group(pg.id)
126+
assert fetched.id == pg.id
127+
assert fetched.name == pg.name
216128

129+
gl_experimental.delete_priming_group(pg.id)
217130

218-
def test_get_priming_group_hits_correct_url():
219-
gl = ExperimentalApi()
220-
with patch("requests.get") as mock_get:
221-
mock_get.return_value = _mock_response(json_data=MOCK_PRIMING_GROUP_DATA)
222-
gl.get_priming_group("pgp_abc123")
223131

224-
called_url = mock_get.call_args[0][0]
225-
assert called_url.endswith("/v1/priming-groups/pgp_abc123")
132+
@pytest.mark.expensive
133+
def test_get_priming_group_unknown_raises(gl_experimental: ExperimentalApi):
134+
with pytest.raises(NotFoundError):
135+
gl_experimental.get_priming_group("pgp_doesnotexist000000000000")
226136

227137

228-
# ---------------------------------------------------------------------------
229-
# delete_priming_group
230-
# ---------------------------------------------------------------------------
138+
@pytest.mark.expensive
139+
def test_delete_priming_group(gl_experimental: ExperimentalApi, detector):
140+
trained = _wait_for_trained_pipeline(gl_experimental, detector)
231141

142+
pg = gl_experimental.create_priming_group(
143+
name=f"test-primer-del-{detector.id}",
144+
source_ml_pipeline_id=trained.id,
145+
)
232146

233-
def test_delete_priming_group_succeeds():
234-
gl = ExperimentalApi()
235-
with patch("requests.delete") as mock_delete:
236-
mock_delete.return_value = _mock_response(status_code=204)
237-
gl.delete_priming_group("pgp_abc123") # should not raise
147+
gl_experimental.delete_priming_group(pg.id)
238148

239-
mock_delete.assert_called_once()
149+
with pytest.raises(NotFoundError):
150+
gl_experimental.get_priming_group(pg.id)
240151

241152

242-
def test_delete_priming_group_404_raises_not_found():
243-
gl = ExperimentalApi()
244-
with patch("requests.delete") as mock_delete:
245-
mock_delete.return_value = _mock_error_response(404)
246-
with pytest.raises(NotFoundError, match="pgp_notexist"):
247-
gl.delete_priming_group("pgp_notexist")
153+
@pytest.mark.expensive
154+
def test_created_priming_group_appears_in_list(gl_experimental: ExperimentalApi, detector):
155+
trained = _wait_for_trained_pipeline(gl_experimental, detector)
248156

157+
pg = gl_experimental.create_priming_group(
158+
name=f"test-primer-list-{detector.id}",
159+
source_ml_pipeline_id=trained.id,
160+
)
249161

250-
def test_delete_priming_group_hits_correct_url():
251-
gl = ExperimentalApi()
252-
with patch("requests.delete") as mock_delete:
253-
mock_delete.return_value = _mock_response(status_code=204)
254-
gl.delete_priming_group("pgp_abc123")
162+
groups = gl_experimental.list_priming_groups()
163+
assert any(g.id == pg.id for g in groups)
255164

256-
called_url = mock_delete.call_args[0][0]
257-
assert called_url.endswith("/v1/priming-groups/pgp_abc123")
165+
gl_experimental.delete_priming_group(pg.id)

0 commit comments

Comments
 (0)