1- # pylint: disable=redefined-builtin,no-name-in-module,no-member
1+ # pylint: disable=redefined-builtin,no-name-in-module,no-member,duplicate-code
2+ import warnings
3+
4+ import pyrecest .backend
25from pyrecest .backend import all , empty , full , repeat , squeeze , stack
36from scipy .optimize import linear_sum_assignment
47from scipy .spatial .distance import cdist
811
912
1013class GlobalNearestNeighbor (AbstractNearestNeighborTracker ):
14+ """Global nearest-neighbor tracker for linear/Gaussian multitarget tracking.
15+
16+ Besides the built-in geometric association costs, this implementation can
17+ optionally fuse an externally computed ``pairwise_cost_matrix`` of shape
18+ ``(n_targets, n_meas)``. This is useful for domains such as longitudinal
19+ calcium-imaging cell tracking where association should depend on arbitrary
20+ pairwise cues like ROI overlap, footprint correlation, or appearance
21+ embeddings in addition to centroid distance.
22+ """
23+
1124 def __init__ (
1225 self ,
1326 initial_prior = None ,
1427 association_param = None ,
1528 log_prior_estimates = True ,
1629 log_posterior_estimates = True ,
1730 ):
31+ default_association_param = {
32+ "distance_metric_pos" : "Mahalanobis" ,
33+ "square_dist" : True ,
34+ "max_new_tracks" : 10 ,
35+ "gating_distance_threshold" : chi2 .ppf (0.999 , 2 ) ** 2 ,
36+ "pairwise_cost_weight" : 1.0 ,
37+ }
1838 if association_param is None :
39+ association_param = default_association_param
40+ else :
1941 association_param = {
20- "distance_metric_pos" : "Mahalanobis" ,
21- "square_dist" : True ,
22- "max_new_tracks" : 10 ,
23- "gating_distance_threshold" : chi2 .ppf (0.999 , 2 ) ** 2 ,
42+ ** default_association_param ,
43+ ** association_param ,
2444 }
2545
2646 super ().__init__ (
@@ -30,14 +50,51 @@ def __init__(
3050 log_posterior_estimates = log_posterior_estimates ,
3151 )
3252
33- # pylint: disable=too-many-locals
53+ @staticmethod
54+ def _validate_pairwise_cost_matrix (pairwise_cost_matrix , n_targets , n_meas ):
55+ if pairwise_cost_matrix is None :
56+ return None
57+ if pairwise_cost_matrix .shape != (n_targets , n_meas ):
58+ raise ValueError (
59+ "pairwise_cost_matrix must have shape "
60+ f"({ n_targets } , { n_meas } ), got { pairwise_cost_matrix .shape } ."
61+ )
62+ return pairwise_cost_matrix
63+
64+ def _apply_pairwise_cost_matrix (self , dists , pairwise_cost_matrix ):
65+ if pairwise_cost_matrix is None :
66+ return dists
67+ pairwise_cost_weight = self .association_param .get ("pairwise_cost_weight" , 1.0 )
68+ if pairwise_cost_weight == 0.0 :
69+ return dists
70+ return dists + pairwise_cost_weight * pairwise_cost_matrix
71+
72+ # pylint: disable=too-many-locals,too-many-positional-arguments
3473 def find_association (
3574 self ,
3675 measurements ,
3776 measurement_matrix ,
3877 cov_mats_meas ,
3978 warn_on_no_meas_for_track = True ,
79+ pairwise_cost_matrix = None ,
4080 ):
81+ """Find the minimum-cost measurement-to-track assignment.
82+
83+ Parameters
84+ ----------
85+ measurements : array-like, shape (dim_meas, n_meas)
86+ Measurements for the current update step.
87+ measurement_matrix : array-like
88+ Linear measurement model.
89+ cov_mats_meas : array-like
90+ Measurement covariance matrix or per-measurement covariance tensor.
91+ warn_on_no_meas_for_track : bool, optional
92+ Whether to emit a warning when a track remains unassigned.
93+ pairwise_cost_matrix : array-like, optional
94+ Additional target/measurement association costs of shape
95+ ``(n_targets, n_meas)``. These costs are added to the geometric cost
96+ matrix before running the Hungarian algorithm.
97+ """
4198 n_targets = len (self .filter_bank )
4299 n_meas = measurements .shape [1 ]
43100
@@ -71,11 +128,14 @@ def find_association(
71128 )
72129
73130 if all_cov_mat_meas_equal and all_cov_mat_state_equal :
131+ shared_cov_mats_meas = (
132+ cov_mats_meas if cov_mats_meas .ndim == 2 else cov_mats_meas [:, :, 0 ]
133+ )
74134 curr_cov_mahalanobis = (
75135 measurement_matrix
76136 @ all_cov_mats_prior [:, :, 0 ]
77137 @ measurement_matrix .T
78- + cov_mats_meas [:, :, 0 ]
138+ + shared_cov_mats_meas
79139 )
80140 dists = cdist (
81141 (measurement_matrix @ all_means_prior ).T ,
@@ -84,6 +144,9 @@ def find_association(
84144 VI = curr_cov_mahalanobis ,
85145 )
86146 elif all_cov_mat_meas_equal :
147+ shared_cov_mats_meas = (
148+ cov_mats_meas if cov_mats_meas .ndim == 2 else cov_mats_meas [:, :, 0 ]
149+ )
87150 all_mats_mahalanobis = empty (
88151 (
89152 measurements .shape [0 ],
@@ -96,7 +159,7 @@ def find_association(
96159 measurement_matrix
97160 @ all_cov_mats_prior [:, :, i ]
98161 @ measurement_matrix .T
99- + cov_mats_meas
162+ + shared_cov_mats_meas
100163 )
101164 for i in range (n_targets ):
102165 dists [i , :] = cdist (
@@ -125,6 +188,11 @@ def find_association(
125188 else :
126189 raise ValueError ("Association scheme not recognized" )
127190
191+ pairwise_cost_matrix = self ._validate_pairwise_cost_matrix (
192+ pairwise_cost_matrix , n_targets , n_meas
193+ )
194+ dists = self ._apply_pairwise_cost_matrix (dists , pairwise_cost_matrix )
195+
128196 # Pad to square and add max_new_tracks rows and columns
129197 pad_to = max (n_targets , n_meas ) + self .association_param ["max_new_tracks" ]
130198 association_matrix = full (
@@ -137,9 +205,48 @@ def find_association(
137205
138206 association = col_ind [:n_targets ]
139207
140- if warn_on_no_meas_for_track and any (association > n_meas ):
208+ if warn_on_no_meas_for_track and any (association >= n_meas ):
141209 print (
142210 "GNN: No measurement was within gating threshold for at least one target."
143211 )
144212
145213 return association
214+
215+ def update_linear (
216+ self ,
217+ measurements ,
218+ measurement_matrix ,
219+ covMatsMeas ,
220+ pairwise_cost_matrix = None ,
221+ ):
222+ """Update the tracker with an optional additional association cost matrix."""
223+ assert (
224+ pyrecest .backend .__backend_name__ == "numpy"
225+ ), "Only supported for numpy backend"
226+ if len (self .filter_bank ) == 0 :
227+ warnings .warn ("Currently, there are zero targets" )
228+ return
229+ assert (
230+ measurement_matrix .shape [0 ] == measurements .shape [0 ]
231+ and measurement_matrix .shape [1 ]
232+ == self .filter_bank [0 ].get_point_estimate ().shape [0 ]
233+ ), (
234+ "Dimensions of measurement matrix must match state and measurement dimensions."
235+ )
236+ association = self .find_association (
237+ measurements ,
238+ measurement_matrix ,
239+ covMatsMeas ,
240+ pairwise_cost_matrix = pairwise_cost_matrix ,
241+ )
242+ currMeasCov = covMatsMeas
243+ for i in range (self .get_number_of_targets ()):
244+ if association [i ] < measurements .shape [1 ]:
245+ if covMatsMeas .ndim != 2 :
246+ currMeasCov = covMatsMeas [:, :, association [i ]]
247+ self .filter_bank [i ].update_linear (
248+ measurements [:, association [i ]], measurement_matrix , currMeasCov
249+ )
250+
251+ if self .log_posterior_estimates :
252+ self .store_posterior_estimates ()
0 commit comments