@@ -57,10 +57,10 @@ def transform(self, source: np.ndarray) -> np.ndarray:
5757class PiecewiseLinearCalibration (Calibration ):
5858 def __init__ (
5959 self ,
60- number_of_splits : int = 20 ,
60+ number_of_splits : int = 10 ,
6161 extrapolate : bool = True ,
6262 use_median : bool = False ,
63- min_samples_per_segment : int = 10 ,
63+ min_samples_per_segment : int = 20 ,
6464 ) -> None :
6565 """
6666 Piece-wise linear calibration based on per-split anchors.
@@ -123,14 +123,20 @@ def fit(self, target: np.ndarray, source: np.ndarray) -> None:
123123 raise CalibrationError ("Source values have zero or invalid range; cannot calibrate." )
124124
125125 boundaries = np .linspace (cal_min , cal_max , self .number_of_splits + 1 , dtype = np .float32 )
126- starts : np .ndarray = np .searchsorted (source , boundaries [:- 1 ], side = "left" ) # type: ignore[var-annotated]
127- ends : np .ndarray = np .searchsorted (source , boundaries [1 :], side = "left" ) # type: ignore[var-annotated]
128-
129- # Filter out sparse segments
130- counts = ends - starts
131- valid_segments = counts >= self .min_samples_per_segment
132- starts = starts [valid_segments ]
133- ends = ends [valid_segments ]
126+ starts_raw : np .ndarray = np .searchsorted (source , boundaries [:- 1 ], side = "left" ) # type: ignore[var-annotated]
127+ ends_raw : np .ndarray = np .searchsorted (source , boundaries [1 :], side = "left" ) # type: ignore[var-annotated]
128+
129+ # Merge adjacent sparse segments by assigning each segment to a group based on
130+ # how many min_samples-sized chunks the cumulative count has crossed so far.
131+ # Segments whose cumulative count falls within the same chunk share a group id
132+ # and are merged into a single anchor.
133+ counts = ends_raw - starts_raw
134+ group_ids = (np .cumsum (counts ) - 1 ) // self .min_samples_per_segment
135+ group_start_indices = np .concatenate (([0 ], np .flatnonzero (np .diff (group_ids )) + 1 ))
136+ group_end_indices = np .concatenate ((group_start_indices [1 :] - 1 , [len (starts_raw ) - 1 ]))
137+
138+ starts = starts_raw [group_start_indices ]
139+ ends = ends_raw [group_end_indices ]
134140
135141 # Compute anchors for all segments
136142 aggregate_func = np .median if self .use_median else np .mean
0 commit comments