Skip to content

Commit f60b2d2

Browse files
committed
Next
1 parent 830ef75 commit f60b2d2

1 file changed

Lines changed: 12 additions & 11 deletions

File tree

pyrecest/filters/kernel_sme_filter.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import filterpy
1+
import bayesian_filters
22
import numpy as np
33
import scipy
44
from pyrecest.distributions import GaussianDistribution
@@ -8,6 +8,7 @@
88

99
from .abstract_multitarget_tracker import AbstractMultitargetTracker
1010

11+
from pyrecest.backend import zeros, hstack, mean, ones, eye
1112

1213
class KernelSMEFilter(AbstractMultitargetTracker):
1314
def __init__(
@@ -40,7 +41,7 @@ def filter_state(self, value: list[GaussianDistribution] | GaussianDistribution)
4041
self.n_targets = len(value)
4142
x_list = [prior.mu for prior in value]
4243
C_list = [prior.C for prior in value]
43-
self.x = np.hstack(x_list)
44+
self.x = hstack(x_list)
4445
self.C = scipy.linalg.block_diag(*C_list)
4546
if self.log_prior_estimates:
4647
self.store_prior_estimates()
@@ -118,7 +119,7 @@ def update_linear(
118119
if gating_threshold is None:
119120
gating_threshold = chi2.ppf(0.99, len(measurements))
120121
n_meas = measurements.shape[1]
121-
kernel_width = np.mean(np.diag(cov_mat_meas)) ** 2
122+
kernel_width = mean(np.diag(cov_mat_meas)) ** 2
122123

123124
if enable_gating:
124125
dists = np.full((self.n_targets, n_meas), np.nan)
@@ -181,10 +182,10 @@ def gen_test_points(measurements, kernel_width):
181182
meas_dim = measurements.shape[0]
182183
n_meas = measurements.shape[1]
183184

184-
# Use high level function from filterpy to generate sigma points
185+
# Use high level function from bayesian_filters to generate sigma points
185186
# Reorder samples to be consistent with the Matlab implementation
186187
all_sample_points_list = [
187-
filterpy.kalman.JulierSigmaPoints(meas_dim).sigma_points(
188+
bayesian_filters.kalman.JulierSigmaPoints(meas_dim).sigma_points(
188189
measurements[:, i], np.sqrt(kernel_width) * np.eye(meas_dim)
189190
)[
190191
[0, 3, 1, 4, 2], : # noqa: E203
@@ -199,12 +200,12 @@ def calc_pseudo_meas(testPoints, measurements, kernel_width):
199200
nMeas = measurements.shape[1]
200201
measDim = measurements.shape[0]
201202
nTestPoints = testPoints.shape[1]
202-
pseudoMeas = np.zeros(nTestPoints)
203+
pseudoMeas = zeros(nTestPoints)
203204
for i in range(nTestPoints):
204205
a_i = testPoints[:, i]
205206
for j in range(nMeas):
206207
pseudoMeas[i] += multivariate_normal.pdf(
207-
a_i, mean=measurements[:, j], cov=kernel_width * np.eye(measDim)
208+
a_i, mean=measurements[:, j], cov=kernel_width * eye(measDim)
208209
)
209210
return pseudoMeas
210211

@@ -238,7 +239,7 @@ def calc_moments(
238239

239240
# Ensure lambdaMultimeas is a vector of length n_targets
240241
if isinstance(lambdaMultimeas, (float, int)):
241-
lambda_vec = float(lambdaMultimeas) * np.ones(n_targets)
242+
lambda_vec = float(lambdaMultimeas) * ones(n_targets)
242243
else:
243244
lambda_vec = np.asarray(lambdaMultimeas, dtype=float).reshape(-1)
244245
assert lambda_vec.size == n_targets, "lambdaMultimeas must be scalar or length n_targets"
@@ -255,7 +256,7 @@ def calc_moments(
255256
]
256257
x_prior_mat = np.reshape(x_prior, (state_dim, n_targets), order="F")
257258

258-
I_meas = np.eye(meas_dim)
259+
I_meas = eye(meas_dim)
259260

260261
# Per-target predicted measurement means and covariances
261262
meas_cov_kernel = [] # H C_l H^T + R + Gamma I
@@ -288,7 +289,7 @@ def calc_moments(
288289
for i in range(n_testpoints):
289290
clutter_pdf[i] = multivariate_normal.pdf(
290291
testPoints[:, i],
291-
mean=np.zeros(meas_dim),
292+
mean=nzeros(meas_dim),
292293
cov=clutter_cov_kernel,
293294
)
294295

@@ -361,7 +362,7 @@ def calc_moments(
361362
sigma_s[ju, iu] = sigma_s[iu, ju]
362363

363364
# Cross-covariance Sigma_xs
364-
sigma_xs = np.zeros((x_prior.shape[0], n_testpoints))
365+
sigma_xs = zeros((x_prior.shape[0], n_testpoints))
365366

366367
# Block columns of C_prior (for K^l)
367368
C_k_full_columns = np.hsplit(C_prior, n_targets)

0 commit comments

Comments
 (0)