Skip to content

Commit 49abc72

Browse files
committed
Improve piece-wise linear calibration:
Merge adjacent sparse segments into shared anchors instead of dropping them, so sparsely populated sections no longer lose calibration coverage; adjust defaults to 10 splits and min_samples_per_segment of 20.
1 parent fe9bb73 commit 49abc72

1 file changed

Lines changed: 16 additions & 10 deletions

File tree

deeplc/calibration.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def transform(self, source: np.ndarray) -> np.ndarray:
5757
class PiecewiseLinearCalibration(Calibration):
5858
def __init__(
5959
self,
60-
number_of_splits: int = 20,
60+
number_of_splits: int = 10,
6161
extrapolate: bool = True,
6262
use_median: bool = False,
63-
min_samples_per_segment: int = 10,
63+
min_samples_per_segment: int = 20,
6464
) -> None:
6565
"""
6666
Piece-wise linear calibration based on per-split anchors.
@@ -123,14 +123,20 @@ def fit(self, target: np.ndarray, source: np.ndarray) -> None:
123123
raise CalibrationError("Source values have zero or invalid range; cannot calibrate.")
124124

125125
boundaries = np.linspace(cal_min, cal_max, self.number_of_splits + 1, dtype=np.float32)
126-
starts: np.ndarray = np.searchsorted(source, boundaries[:-1], side="left") # type: ignore[var-annotated]
127-
ends: np.ndarray = np.searchsorted(source, boundaries[1:], side="left") # type: ignore[var-annotated]
128-
129-
# Filter out sparse segments
130-
counts = ends - starts
131-
valid_segments = counts >= self.min_samples_per_segment
132-
starts = starts[valid_segments]
133-
ends = ends[valid_segments]
126+
starts_raw: np.ndarray = np.searchsorted(source, boundaries[:-1], side="left") # type: ignore[var-annotated]
127+
ends_raw: np.ndarray = np.searchsorted(source, boundaries[1:], side="left") # type: ignore[var-annotated]
128+
129+
# Merge adjacent sparse segments by assigning each segment to a group based on
130+
# how many min_samples-sized chunks the cumulative count has crossed so far.
131+
# Segments whose cumulative count falls within the same chunk share a group id
132+
# and are merged into a single anchor.
133+
counts = ends_raw - starts_raw
134+
group_ids = (np.cumsum(counts) - 1) // self.min_samples_per_segment
135+
group_start_indices = np.concatenate(([0], np.flatnonzero(np.diff(group_ids)) + 1))
136+
group_end_indices = np.concatenate((group_start_indices[1:] - 1, [len(starts_raw) - 1]))
137+
138+
starts = starts_raw[group_start_indices]
139+
ends = ends_raw[group_end_indices]
134140

135141
# Compute anchors for all segments
136142
aggregate_func = np.median if self.use_median else np.mean

0 commit comments

Comments
 (0)