Skip to content

Commit 0fbd05e

Browse files
author
Tim Huff
committed
code cleanup
1 parent 18a33b2 commit 0fbd05e

2 files changed

Lines changed: 80 additions & 106 deletions

File tree

src/groundlight/edge/config.py

Lines changed: 75 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Optional, Union
22

33
from model import Detector
4-
from pydantic import BaseModel, ConfigDict, Field, model_validator
4+
from pydantic import BaseModel, ConfigDict, Field, model_serializer, model_validator
55
from typing_extensions import Self
66

77

@@ -74,63 +74,6 @@ class DetectorConfig(BaseModel):
7474
edge_inference_config: str = Field(..., description="Config for edge inference.")
7575

7676

77-
def _validate_detector_config_state(
78-
edge_inference_configs: dict[str, InferenceConfig], detectors: list[DetectorConfig]
79-
) -> None:
80-
for name, config in edge_inference_configs.items():
81-
if name != config.name:
82-
raise ValueError(f"Edge inference config key '{name}' must match InferenceConfig.name '{config.name}'.")
83-
84-
seen_detector_ids = set()
85-
duplicate_detector_ids = set()
86-
for detector_config in detectors:
87-
detector_id = detector_config.detector_id
88-
if detector_id in seen_detector_ids:
89-
duplicate_detector_ids.add(detector_id)
90-
else:
91-
seen_detector_ids.add(detector_id)
92-
if duplicate_detector_ids:
93-
duplicates = ", ".join(sorted(duplicate_detector_ids))
94-
raise ValueError(f"Duplicate detector IDs are not allowed: {duplicates}.")
95-
96-
for detector_config in detectors:
97-
if detector_config.edge_inference_config not in edge_inference_configs:
98-
raise ValueError(f"Edge inference config '{detector_config.edge_inference_config}' not defined.")
99-
100-
101-
def _add_detector_to_state(
102-
edge_inference_configs: dict[str, InferenceConfig],
103-
detectors: list[DetectorConfig],
104-
detector: Union[str, Detector],
105-
edge_inference_config: Union[str, InferenceConfig],
106-
) -> DetectorConfig:
107-
detector_id = detector.id if isinstance(detector, Detector) else detector
108-
if any(existing.detector_id == detector_id for existing in detectors):
109-
raise ValueError(f"A detector with ID '{detector_id}' already exists.")
110-
if isinstance(edge_inference_config, InferenceConfig):
111-
config = edge_inference_config
112-
existing = edge_inference_configs.get(config.name)
113-
if existing is None:
114-
edge_inference_configs[config.name] = config
115-
elif existing != config:
116-
raise ValueError(f"A different inference config named '{config.name}' is already registered.")
117-
config_name = config.name
118-
else:
119-
config_name = edge_inference_config
120-
if config_name not in edge_inference_configs:
121-
raise ValueError(
122-
f"Edge inference config '{config_name}' not defined. "
123-
f"Available configs: {list(edge_inference_configs.keys())}"
124-
)
125-
126-
detector_config = DetectorConfig(
127-
detector_id=detector_id,
128-
edge_inference_config=config_name,
129-
)
130-
detectors.append(detector_config)
131-
return detector_config
132-
133-
13477
class DetectorsConfig(BaseModel):
13578
"""
13679
Detector and inference-config mappings for edge inference.
@@ -141,11 +84,59 @@ class DetectorsConfig(BaseModel):
14184

14285
@model_validator(mode="after")
14386
def validate_inference_configs(self):
144-
_validate_detector_config_state(self.edge_inference_configs, self.detectors)
87+
"""
88+
Validates detector config state.
89+
Raises ValueError if dict keys mismatch InferenceConfig.name, detector IDs are duplicated,
90+
or any detector references an undefined inference config.
91+
"""
92+
for name, config in self.edge_inference_configs.items():
93+
if name != config.name:
94+
raise ValueError(f"Edge inference config key '{name}' must match InferenceConfig.name '{config.name}'.")
95+
96+
seen_detector_ids = set()
97+
duplicate_detector_ids = set()
98+
for detector_config in self.detectors:
99+
detector_id = detector_config.detector_id
100+
if detector_id in seen_detector_ids:
101+
duplicate_detector_ids.add(detector_id)
102+
else:
103+
seen_detector_ids.add(detector_id)
104+
if duplicate_detector_ids:
105+
duplicates = ", ".join(sorted(duplicate_detector_ids))
106+
raise ValueError(f"Duplicate detector IDs are not allowed: {duplicates}.")
107+
108+
for detector_config in self.detectors:
109+
if detector_config.edge_inference_config not in self.edge_inference_configs:
110+
raise ValueError(f"Edge inference config '{detector_config.edge_inference_config}' not defined.")
145111
return self
146112

147-
def add_detector(self, detector: Union[str, Detector], edge_inference_config: Union[str, InferenceConfig]) -> None:
148-
_add_detector_to_state(self.edge_inference_configs, self.detectors, detector, edge_inference_config)
113+
def add_detector(self, detector: Union[str, Detector], edge_inference_config: InferenceConfig) -> None:
114+
"""Add a detector with the given inference config. Accepts detector ID or Detector object."""
115+
detector_id = detector.id if isinstance(detector, Detector) else detector
116+
if any(existing.detector_id == detector_id for existing in self.detectors):
117+
raise ValueError(f"A detector with ID '{detector_id}' already exists.")
118+
119+
existing = self.edge_inference_configs.get(edge_inference_config.name)
120+
if existing is None:
121+
self.edge_inference_configs[edge_inference_config.name] = edge_inference_config
122+
elif existing != edge_inference_config:
123+
raise ValueError(
124+
f"A different inference config named '{edge_inference_config.name}' is already registered."
125+
)
126+
127+
self.detectors.append(
128+
DetectorConfig(detector_id=detector_id, edge_inference_config=edge_inference_config.name)
129+
)
130+
131+
132+
def to_payload(self) -> dict[str, object]:
133+
"""Return flattened detector payload used by edge-endpoint config HTTP APIs."""
134+
return {
135+
"edge_inference_configs": {
136+
name: config.model_dump() for name, config in self.edge_inference_configs.items()
137+
},
138+
"detectors": [detector.model_dump() for detector in self.detectors],
139+
}
149140

150141

151142
class EdgeEndpointConfig(BaseModel):
@@ -154,27 +145,29 @@ class EdgeEndpointConfig(BaseModel):
154145
"""
155146

156147
global_config: GlobalConfig = Field(default_factory=GlobalConfig)
157-
edge_inference_configs: dict[str, InferenceConfig] = Field(default_factory=dict)
158-
detectors: list[DetectorConfig] = Field(default_factory=list)
159-
160-
@model_validator(mode="after")
161-
def validate_inference_configs(self):
162-
_validate_detector_config_state(self.edge_inference_configs, self.detectors)
163-
return self
164-
165-
def add_detector(self, detector: Union[str, Detector], edge_inference_config: Union[str, InferenceConfig]) -> None:
166-
_add_detector_to_state(self.edge_inference_configs, self.detectors, detector, edge_inference_config)
167-
168-
@classmethod
169-
def from_detectors_config(
170-
cls, detectors_config: "DetectorsConfig", global_config: Optional[GlobalConfig] = None
171-
) -> "EdgeEndpointConfig":
172-
copied_config = detectors_config.model_copy(deep=True)
173-
return cls(
174-
global_config=global_config or GlobalConfig(),
175-
edge_inference_configs=copied_config.edge_inference_configs,
176-
detectors=copied_config.detectors,
177-
)
148+
detectors_config: DetectorsConfig = Field(default_factory=DetectorsConfig)
149+
150+
@property
151+
def edge_inference_configs(self) -> dict[str, InferenceConfig]:
152+
"""Convenience accessor for detector inference config map."""
153+
return self.detectors_config.edge_inference_configs
154+
155+
@property
156+
def detectors(self) -> list[DetectorConfig]:
157+
"""Convenience accessor for detector assignments."""
158+
return self.detectors_config.detectors
159+
160+
@model_serializer(mode="plain")
161+
def serialize(self):
162+
"""Serialize to the flattened shape expected by edge-endpoint configs."""
163+
return {
164+
"global_config": self.global_config.model_dump(),
165+
**self.detectors_config.to_payload(),
166+
}
167+
168+
def add_detector(self, detector: Union[str, Detector], edge_inference_config: InferenceConfig) -> None:
169+
"""Add a detector with the given inference config. Accepts detector ID or Detector object."""
170+
self.detectors_config.add_detector(detector, edge_inference_config)
178171

179172

180173
# Preset inference configs matching the standard edge-endpoint defaults.

test/unit/test_edge_config.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
InferenceConfig,
1515
)
1616

17-
ONE_DETECTOR = 1
1817
TWO_DETECTORS = 2
1918
THREE_DETECTORS = 3
2019
CUSTOM_REFRESH_RATE = 10.0
@@ -119,13 +118,6 @@ def test_add_detector_accepts_detector_object():
119118
assert [detector.detector_id for detector in config.detectors] == ["det_1"]
120119

121120

122-
def test_add_detector_accepts_string_inference_config_name():
123-
config = EdgeEndpointConfig()
124-
config.edge_inference_configs["default"] = DEFAULT
125-
config.add_detector("det_1", "default")
126-
127-
assert [detector.edge_inference_config for detector in config.detectors] == ["default"]
128-
129121

130122
def test_disabled_preset_can_be_used():
131123
config = EdgeEndpointConfig()
@@ -135,26 +127,15 @@ def test_disabled_preset_can_be_used():
135127
assert config.edge_inference_configs["disabled"] == DISABLED
136128

137129

138-
def test_from_detectors_config_copies_detector_data():
139-
detectors_config = DetectorsConfig()
140-
detectors_config.add_detector("det_1", DEFAULT)
141-
142-
config = EdgeEndpointConfig.from_detectors_config(detectors_config)
143-
detectors_config.add_detector("det_2", DEFAULT)
144-
145-
assert len(config.detectors) == ONE_DETECTOR
146-
assert len(detectors_config.detectors) == TWO_DETECTORS
147-
148-
149-
def test_from_detectors_config_uses_custom_global_config():
130+
def test_detectors_config_to_payload_shape():
150131
detectors_config = DetectorsConfig()
151132
detectors_config.add_detector("det_1", DEFAULT)
152-
custom_global_config = GlobalConfig(refresh_rate=CUSTOM_REFRESH_RATE, confident_audit_rate=CUSTOM_AUDIT_RATE)
133+
detectors_config.add_detector("det_2", NO_CLOUD)
153134

154-
config = EdgeEndpointConfig.from_detectors_config(detectors_config, global_config=custom_global_config)
135+
payload = detectors_config.to_payload()
155136

156-
assert config.global_config == custom_global_config
157-
assert len(config.detectors) == ONE_DETECTOR
137+
assert len(payload["detectors"]) == TWO_DETECTORS
138+
assert set(payload["edge_inference_configs"].keys()) == {"default", "no_cloud"}
158139

159140

160141
def test_model_dump_shape_for_edge_endpoint_config():

0 commit comments

Comments
 (0)