Skip to content

Commit 3e6a64f

Browse files
Added pairwise association-cost for GNN (#1739)
* Added pairwise association-cost for GNN * Fix pylint/flake8 issues: remove unused imports, add pylint disables Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/3594bc92-b3a4-4c14-b600-7754a7f5398d Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com> * Suppress duplicate-code (R0801) in GNN filter and pairwise-cost test Agent-Logs-Url: https://github.com/FlorianPfaff/PyRecEst/sessions/37f5b6c8-c759-4657-8c8f-3a514ce88de3 Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: FlorianPfaff <6773539+FlorianPfaff@users.noreply.github.com>
1 parent fb5e8e7 commit 3e6a64f

2 files changed

Lines changed: 264 additions & 9 deletions

File tree

pyrecest/filters/global_nearest_neighbor.py

Lines changed: 116 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
# pylint: disable=redefined-builtin,no-name-in-module,no-member
1+
# pylint: disable=redefined-builtin,no-name-in-module,no-member,duplicate-code
2+
import warnings
3+
4+
import pyrecest.backend
25
from pyrecest.backend import all, empty, full, repeat, squeeze, stack
36
from scipy.optimize import linear_sum_assignment
47
from scipy.spatial.distance import cdist
@@ -8,19 +11,36 @@
811

912

1013
class GlobalNearestNeighbor(AbstractNearestNeighborTracker):
14+
"""Global nearest-neighbor tracker for linear/Gaussian multitarget tracking.
15+
16+
Besides the built-in geometric association costs, this implementation can
17+
optionally fuse an externally computed ``pairwise_cost_matrix`` of shape
18+
``(n_targets, n_meas)``. This is useful for domains such as longitudinal
19+
calcium-imaging cell tracking where association should depend on arbitrary
20+
pairwise cues like ROI overlap, footprint correlation, or appearance
21+
embeddings in addition to centroid distance.
22+
"""
23+
1124
def __init__(
1225
self,
1326
initial_prior=None,
1427
association_param=None,
1528
log_prior_estimates=True,
1629
log_posterior_estimates=True,
1730
):
31+
default_association_param = {
32+
"distance_metric_pos": "Mahalanobis",
33+
"square_dist": True,
34+
"max_new_tracks": 10,
35+
"gating_distance_threshold": chi2.ppf(0.999, 2) ** 2,
36+
"pairwise_cost_weight": 1.0,
37+
}
1838
if association_param is None:
39+
association_param = default_association_param
40+
else:
1941
association_param = {
20-
"distance_metric_pos": "Mahalanobis",
21-
"square_dist": True,
22-
"max_new_tracks": 10,
23-
"gating_distance_threshold": chi2.ppf(0.999, 2) ** 2,
42+
**default_association_param,
43+
**association_param,
2444
}
2545

2646
super().__init__(
@@ -30,14 +50,51 @@ def __init__(
3050
log_posterior_estimates=log_posterior_estimates,
3151
)
3252

33-
# pylint: disable=too-many-locals
53+
@staticmethod
54+
def _validate_pairwise_cost_matrix(pairwise_cost_matrix, n_targets, n_meas):
55+
if pairwise_cost_matrix is None:
56+
return None
57+
if pairwise_cost_matrix.shape != (n_targets, n_meas):
58+
raise ValueError(
59+
"pairwise_cost_matrix must have shape "
60+
f"({n_targets}, {n_meas}), got {pairwise_cost_matrix.shape}."
61+
)
62+
return pairwise_cost_matrix
63+
64+
def _apply_pairwise_cost_matrix(self, dists, pairwise_cost_matrix):
65+
if pairwise_cost_matrix is None:
66+
return dists
67+
pairwise_cost_weight = self.association_param.get("pairwise_cost_weight", 1.0)
68+
if pairwise_cost_weight == 0.0:
69+
return dists
70+
return dists + pairwise_cost_weight * pairwise_cost_matrix
71+
72+
# pylint: disable=too-many-locals,too-many-positional-arguments
3473
def find_association(
3574
self,
3675
measurements,
3776
measurement_matrix,
3877
cov_mats_meas,
3978
warn_on_no_meas_for_track=True,
79+
pairwise_cost_matrix=None,
4080
):
81+
"""Find the minimum-cost measurement-to-track assignment.
82+
83+
Parameters
84+
----------
85+
measurements : array-like, shape (dim_meas, n_meas)
86+
Measurements for the current update step.
87+
measurement_matrix : array-like
88+
Linear measurement model.
89+
cov_mats_meas : array-like
90+
Measurement covariance matrix or per-measurement covariance tensor.
91+
warn_on_no_meas_for_track : bool, optional
92+
Whether to emit a warning when a track remains unassigned.
93+
pairwise_cost_matrix : array-like, optional
94+
Additional target/measurement association costs of shape
95+
``(n_targets, n_meas)``. These costs are added to the geometric cost
96+
matrix before running the Hungarian algorithm.
97+
"""
4198
n_targets = len(self.filter_bank)
4299
n_meas = measurements.shape[1]
43100

@@ -71,11 +128,14 @@ def find_association(
71128
)
72129

73130
if all_cov_mat_meas_equal and all_cov_mat_state_equal:
131+
shared_cov_mats_meas = (
132+
cov_mats_meas if cov_mats_meas.ndim == 2 else cov_mats_meas[:, :, 0]
133+
)
74134
curr_cov_mahalanobis = (
75135
measurement_matrix
76136
@ all_cov_mats_prior[:, :, 0]
77137
@ measurement_matrix.T
78-
+ cov_mats_meas[:, :, 0]
138+
+ shared_cov_mats_meas
79139
)
80140
dists = cdist(
81141
(measurement_matrix @ all_means_prior).T,
@@ -84,6 +144,9 @@ def find_association(
84144
VI=curr_cov_mahalanobis,
85145
)
86146
elif all_cov_mat_meas_equal:
147+
shared_cov_mats_meas = (
148+
cov_mats_meas if cov_mats_meas.ndim == 2 else cov_mats_meas[:, :, 0]
149+
)
87150
all_mats_mahalanobis = empty(
88151
(
89152
measurements.shape[0],
@@ -96,7 +159,7 @@ def find_association(
96159
measurement_matrix
97160
@ all_cov_mats_prior[:, :, i]
98161
@ measurement_matrix.T
99-
+ cov_mats_meas
162+
+ shared_cov_mats_meas
100163
)
101164
for i in range(n_targets):
102165
dists[i, :] = cdist(
@@ -125,6 +188,11 @@ def find_association(
125188
else:
126189
raise ValueError("Association scheme not recognized")
127190

191+
pairwise_cost_matrix = self._validate_pairwise_cost_matrix(
192+
pairwise_cost_matrix, n_targets, n_meas
193+
)
194+
dists = self._apply_pairwise_cost_matrix(dists, pairwise_cost_matrix)
195+
128196
# Pad to square and add max_new_tracks rows and columns
129197
pad_to = max(n_targets, n_meas) + self.association_param["max_new_tracks"]
130198
association_matrix = full(
@@ -137,9 +205,48 @@ def find_association(
137205

138206
association = col_ind[:n_targets]
139207

140-
if warn_on_no_meas_for_track and any(association > n_meas):
208+
if warn_on_no_meas_for_track and any(association >= n_meas):
141209
print(
142210
"GNN: No measurement was within gating threshold for at least one target."
143211
)
144212

145213
return association
214+
215+
def update_linear(
216+
self,
217+
measurements,
218+
measurement_matrix,
219+
covMatsMeas,
220+
pairwise_cost_matrix=None,
221+
):
222+
"""Update the tracker with an optional additional association cost matrix."""
223+
assert (
224+
pyrecest.backend.__backend_name__ == "numpy"
225+
), "Only supported for numpy backend"
226+
if len(self.filter_bank) == 0:
227+
warnings.warn("Currently, there are zero targets")
228+
return
229+
assert (
230+
measurement_matrix.shape[0] == measurements.shape[0]
231+
and measurement_matrix.shape[1]
232+
== self.filter_bank[0].get_point_estimate().shape[0]
233+
), (
234+
"Dimensions of measurement matrix must match state and measurement dimensions."
235+
)
236+
association = self.find_association(
237+
measurements,
238+
measurement_matrix,
239+
covMatsMeas,
240+
pairwise_cost_matrix=pairwise_cost_matrix,
241+
)
242+
currMeasCov = covMatsMeas
243+
for i in range(self.get_number_of_targets()):
244+
if association[i] < measurements.shape[1]:
245+
if covMatsMeas.ndim != 2:
246+
currMeasCov = covMatsMeas[:, :, association[i]]
247+
self.filter_bank[i].update_linear(
248+
measurements[:, association[i]], measurement_matrix, currMeasCov
249+
)
250+
251+
if self.log_posterior_estimates:
252+
self.store_posterior_estimates()
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import unittest
2+
3+
import numpy.testing as npt
4+
5+
# pylint: disable=no-member,duplicate-code
6+
import pyrecest.backend
7+
8+
from pyrecest.backend import array, column_stack, diag, eye, zeros
9+
from pyrecest.distributions import GaussianDistribution
10+
from pyrecest.filters import KalmanFilter
11+
from pyrecest.filters.global_nearest_neighbor import GlobalNearestNeighbor
12+
13+
14+
class GlobalNearestNeighborPairwiseCostTest(unittest.TestCase):
15+
"""Regression tests for pairwise-cost fusion in GNN."""
16+
17+
def setUp(self):
18+
self.kfs_init = [
19+
KalmanFilter(
20+
GaussianDistribution(zeros(4), diag(array([1.0, 2.0, 3.0, 4.0])))
21+
),
22+
KalmanFilter(
23+
GaussianDistribution(
24+
array([1.0, 2.0, 3.0, 4.0]), diag(array([2.0, 2.0, 2.0, 2.0]))
25+
)
26+
),
27+
KalmanFilter(
28+
GaussianDistribution(
29+
-array([1.0, 2.0, 3.0, 4.0]), diag(array([4.0, 3.0, 2.0, 1.0]))
30+
)
31+
),
32+
]
33+
self.meas_mat = array([[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]])
34+
35+
@unittest.skipIf(
36+
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
37+
reason="Not supported on this backend",
38+
)
39+
def test_association_supports_shared_2d_measurement_covariance(self):
40+
tracker = GlobalNearestNeighbor()
41+
tracker.filter_state = [
42+
KalmanFilter(
43+
GaussianDistribution(zeros(4), diag(array([1.0, 2.0, 3.0, 4.0])))
44+
),
45+
KalmanFilter(
46+
GaussianDistribution(
47+
array([1.0, 2.0, 3.0, 4.0]), diag(array([1.0, 2.0, 3.0, 4.0]))
48+
)
49+
),
50+
KalmanFilter(
51+
GaussianDistribution(
52+
-array([1.0, 2.0, 3.0, 4.0]), diag(array([1.0, 2.0, 3.0, 4.0]))
53+
)
54+
),
55+
]
56+
all_gaussians = [kf.filter_state for kf in tracker.filter_bank]
57+
perfect_meas_ordered = self.meas_mat @ column_stack(
58+
[gaussian.mu for gaussian in all_gaussians]
59+
)
60+
association = tracker.find_association(
61+
perfect_meas_ordered, self.meas_mat, eye(2)
62+
)
63+
npt.assert_array_equal(association, [0, 1, 2])
64+
65+
@unittest.skipIf(
66+
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
67+
reason="Not supported on this backend",
68+
)
69+
def test_pairwise_cost_matrix_can_override_geometric_assignment(self):
70+
tracker = GlobalNearestNeighbor()
71+
tracker.filter_state = self.kfs_init
72+
all_gaussians = [kf.filter_state for kf in self.kfs_init]
73+
perfect_meas_ordered = self.meas_mat @ column_stack(
74+
[gaussian.mu for gaussian in all_gaussians]
75+
)
76+
77+
pairwise_cost_matrix = array(
78+
[[10.0, -10.0, 10.0], [10.0, 10.0, -10.0], [-10.0, 10.0, 10.0]]
79+
)
80+
association = tracker.find_association(
81+
perfect_meas_ordered,
82+
self.meas_mat,
83+
eye(2),
84+
pairwise_cost_matrix=pairwise_cost_matrix,
85+
)
86+
npt.assert_array_equal(association, [1, 2, 0])
87+
88+
@unittest.skipIf(
89+
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
90+
reason="Not supported on this backend",
91+
)
92+
def test_pairwise_cost_matrix_shape_is_validated(self):
93+
tracker = GlobalNearestNeighbor()
94+
tracker.filter_state = self.kfs_init
95+
all_gaussians = [kf.filter_state for kf in self.kfs_init]
96+
perfect_meas_ordered = self.meas_mat @ column_stack(
97+
[gaussian.mu for gaussian in all_gaussians]
98+
)
99+
100+
with self.assertRaises(ValueError):
101+
tracker.find_association(
102+
perfect_meas_ordered,
103+
self.meas_mat,
104+
eye(2),
105+
pairwise_cost_matrix=zeros((2, 3)),
106+
)
107+
108+
@unittest.skipIf(
109+
pyrecest.backend.__backend_name__ in ("pytorch", "jax"),
110+
reason="Not supported on this backend",
111+
)
112+
def test_update_linear_accepts_pairwise_cost_matrix(self):
113+
tracker_manual = GlobalNearestNeighbor()
114+
tracker_pairwise = GlobalNearestNeighbor()
115+
tracker_manual.filter_state = self.kfs_init
116+
tracker_pairwise.filter_state = self.kfs_init
117+
118+
all_gaussians = [kf.filter_state for kf in self.kfs_init]
119+
perfect_meas_ordered = self.meas_mat @ column_stack(
120+
[gaussian.mu for gaussian in all_gaussians]
121+
)
122+
forced_permutation = [1, 2, 0]
123+
124+
for track_index, meas_index in enumerate(forced_permutation):
125+
tracker_manual.filter_bank[track_index].update_linear(
126+
perfect_meas_ordered[:, meas_index],
127+
self.meas_mat,
128+
eye(2),
129+
)
130+
131+
pairwise_cost_matrix = array(
132+
[[10.0, -10.0, 10.0], [10.0, 10.0, -10.0], [-10.0, 10.0, 10.0]]
133+
)
134+
tracker_pairwise.update_linear(
135+
perfect_meas_ordered,
136+
self.meas_mat,
137+
eye(2),
138+
pairwise_cost_matrix=pairwise_cost_matrix,
139+
)
140+
141+
npt.assert_allclose(
142+
tracker_pairwise.get_point_estimate(),
143+
tracker_manual.get_point_estimate(),
144+
)
145+
146+
147+
if __name__ == "__main__":
148+
unittest.main()

0 commit comments

Comments
 (0)