Skip to content

Commit 590b446

Browse files
committed
Next
1 parent 830ef75 commit 590b446

2 files changed

Lines changed: 136 additions & 141 deletions

File tree

pyrecest/filters/kernel_sme_filter.py

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
import filterpy
1+
import bayesian_filters
22
import numpy as np
33
import scipy
4+
from pyrecest.backend import eye, hstack, mean, ones, zeros
45
from pyrecest.distributions import GaussianDistribution
56
from scipy.linalg import block_diag
67
from scipy.spatial.distance import cdist
@@ -40,7 +41,7 @@ def filter_state(self, value: list[GaussianDistribution] | GaussianDistribution)
4041
self.n_targets = len(value)
4142
x_list = [prior.mu for prior in value]
4243
C_list = [prior.C for prior in value]
43-
self.x = np.hstack(x_list)
44+
self.x = hstack(x_list)
4445
self.C = scipy.linalg.block_diag(*C_list)
4546
if self.log_prior_estimates:
4647
self.store_prior_estimates()
@@ -118,7 +119,7 @@ def update_linear(
118119
if gating_threshold is None:
119120
gating_threshold = chi2.ppf(0.99, len(measurements))
120121
n_meas = measurements.shape[1]
121-
kernel_width = np.mean(np.diag(cov_mat_meas)) ** 2
122+
kernel_width = mean(np.diag(cov_mat_meas)) ** 2
122123

123124
if enable_gating:
124125
dists = np.full((self.n_targets, n_meas), np.nan)
@@ -181,10 +182,10 @@ def gen_test_points(measurements, kernel_width):
181182
meas_dim = measurements.shape[0]
182183
n_meas = measurements.shape[1]
183184

184-
# Use high level function from filterpy to generate sigma points
185+
# Use high level function from bayesian_filters to generate sigma points
185186
# Reorder samples to be consistent with the Matlab implementation
186187
all_sample_points_list = [
187-
filterpy.kalman.JulierSigmaPoints(meas_dim).sigma_points(
188+
bayesian_filters.kalman.JulierSigmaPoints(meas_dim).sigma_points(
188189
measurements[:, i], np.sqrt(kernel_width) * np.eye(meas_dim)
189190
)[
190191
[0, 3, 1, 4, 2], : # noqa: E203
@@ -199,12 +200,12 @@ def calc_pseudo_meas(testPoints, measurements, kernel_width):
199200
nMeas = measurements.shape[1]
200201
measDim = measurements.shape[0]
201202
nTestPoints = testPoints.shape[1]
202-
pseudoMeas = np.zeros(nTestPoints)
203+
pseudoMeas = zeros(nTestPoints)
203204
for i in range(nTestPoints):
204205
a_i = testPoints[:, i]
205206
for j in range(nMeas):
206207
pseudoMeas[i] += multivariate_normal.pdf(
207-
a_i, mean=measurements[:, j], cov=kernel_width * np.eye(measDim)
208+
a_i, mean=measurements[:, j], cov=kernel_width * eye(measDim)
208209
)
209210
return pseudoMeas
210211

@@ -238,35 +239,36 @@ def calc_moments(
238239

239240
# Ensure lambdaMultimeas is a vector of length n_targets
240241
if isinstance(lambdaMultimeas, (float, int)):
241-
lambda_vec = float(lambdaMultimeas) * np.ones(n_targets)
242+
lambda_vec = float(lambdaMultimeas) * ones(n_targets)
242243
else:
243244
lambda_vec = np.asarray(lambdaMultimeas, dtype=float).reshape(-1)
244-
assert lambda_vec.size == n_targets, "lambdaMultimeas must be scalar or length n_targets"
245+
assert (
246+
lambda_vec.size == n_targets
247+
), "lambdaMultimeas must be scalar or length n_targets"
245248

246249
lam_c = float(falseAlarmRate)
247250

248251
state_dim = x_prior.size // n_targets
249252

250253
# Per-target state blocks and block covariances
251254
C_blocks_list = [
252-
C_prior[i * state_dim : (i + 1) * state_dim,
253-
i * state_dim : (i + 1) * state_dim]
255+
C_prior[
256+
i * state_dim : (i + 1) * state_dim, i * state_dim : (i + 1) * state_dim
257+
]
254258
for i in range(n_targets)
255259
]
256260
x_prior_mat = np.reshape(x_prior, (state_dim, n_targets), order="F")
257261

258-
I_meas = np.eye(meas_dim)
262+
I_meas = eye(meas_dim)
259263

260264
# Per-target predicted measurement means and covariances
261-
meas_cov_kernel = [] # H C_l H^T + R + Gamma I
265+
meas_cov_kernel = [] # H C_l H^T + R + Gamma I
262266
meas_cov_half_kernel = [] # H C_l H^T + R + 0.5 Gamma I
263267
meas_means = []
264268

265269
for k in range(n_targets):
266270
S = (
267-
measurement_matrix
268-
@ C_blocks_list[k]
269-
@ measurement_matrix.T
271+
measurement_matrix @ C_blocks_list[k] @ measurement_matrix.T
270272
+ covMatMeas
271273
)
272274
meas_cov_kernel.append(S + kernel_width * I_meas)
@@ -288,7 +290,7 @@ def calc_moments(
288290
for i in range(n_testpoints):
289291
clutter_pdf[i] = multivariate_normal.pdf(
290292
testPoints[:, i],
291-
mean=np.zeros(meas_dim),
293+
mean=nzeros(meas_dim),
292294
cov=clutter_cov_kernel,
293295
)
294296

@@ -331,9 +333,7 @@ def calc_moments(
331333
)
332334

333335
# --- term 2: kernel_between * sum_l (lambda_l^2 P_l^{Gamma/2}(midpoint)) ---
334-
term2 = kernel_between * np.sum(
335-
(lambda_vec ** 2) * Pij[i, j, :]
336-
)
336+
term2 = kernel_between * np.sum((lambda_vec**2) * Pij[i, j, :])
337337

338338
# --- clutter-related terms (3)–(6) ---
339339
clutter_i = clutter_pdf[i]
@@ -345,39 +345,34 @@ def calc_moments(
345345
cov=clutter_cov_kernel,
346346
)
347347

348-
term3 = (lam_c ** 2) * clutter_i * clutter_j
348+
term3 = (lam_c**2) * clutter_i * clutter_j
349349
term4 = mu_s[i] * lam_c * clutter_j
350350
term5 = mu_s[j] * lam_c * clutter_i
351351
term6 = lam_c * kernel_between * clutter_mid_pdf
352352

353353
# Full Sigma^{s_i, s_j}
354354
sigma_s[i, j] = (
355-
term1 + term2 + term3 + term4 + term5 + term6
356-
- mu_s[i] * mu_s[j]
355+
term1 + term2 + term3 + term4 + term5 + term6 - mu_s[i] * mu_s[j]
357356
)
358357

359358
# Symmetrize Sigma_s
360359
iu, ju = np.triu_indices(n_testpoints, k=1)
361360
sigma_s[ju, iu] = sigma_s[iu, ju]
362361

363362
# Cross-covariance Sigma_xs
364-
sigma_xs = np.zeros((x_prior.shape[0], n_testpoints))
363+
sigma_xs = zeros((x_prior.shape[0], n_testpoints))
365364

366365
# Block columns of C_prior (for K^l)
367366
C_k_full_columns = np.hsplit(C_prior, n_targets)
368367
K_list = []
369368
for k in range(n_targets):
370369
S_full = (
371-
measurement_matrix
372-
@ C_blocks_list[k]
373-
@ measurement_matrix.T
370+
measurement_matrix @ C_blocks_list[k] @ measurement_matrix.T
374371
+ covMatMeas
375372
+ kernel_width * I_meas
376373
)
377374
K_list.append(
378-
C_k_full_columns[k]
379-
@ measurement_matrix.T
380-
@ np.linalg.inv(S_full)
375+
C_k_full_columns[k] @ measurement_matrix.T @ np.linalg.inv(S_full)
381376
)
382377

383378
# Sigma^{x, s_i}
@@ -390,8 +385,7 @@ def calc_moments(
390385
* P_target[i, k]
391386
* (
392387
x_prior
393-
+ K_list[k]
394-
@ (z - measurement_matrix @ x_prior_mat[:, k])
388+
+ K_list[k] @ (z - measurement_matrix @ x_prior_mat[:, k])
395389
)
396390
)
397391
# clutter term + centering

0 commit comments

Comments
 (0)