Skip to content

Commit dcdb4ef

Browse files
committed
Added cart prod pf
1 parent 929f181 commit dcdb4ef

4 files changed

Lines changed: 255 additions & 0 deletions

File tree

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from ..hypersphere_subset.abstract_hyperhemispherical_distribution import (
2+
AbstractHyperhemisphericalDistribution,
3+
)
4+
from ..hypersphere_subset.abstract_hypersphere_subset_dirac_distribution import (
5+
AbstractHypersphereSubsetDiracDistribution,
6+
)
7+
from ..hypersphere_subset.abstract_hyperspherical_distribution import (
8+
AbstractHypersphericalDistribution,
9+
)
10+
11+
12+
class HyperhemisphericalDiracDistribution(
    AbstractHypersphereSubsetDiracDistribution, AbstractHyperhemisphericalDistribution
):
    """Dirac (particle-based) distribution on the hyperhemisphere.

    Pure mixin combination: all behavior comes from the generic
    hypersphere-subset Dirac base and the hyperhemispherical domain base.
    """
16+
17+
18+
class HypersphericalDiracDistribution(
    AbstractHypersphereSubsetDiracDistribution, AbstractHypersphericalDistribution
):
    """Dirac (particle-based) distribution on the full hypersphere.

    Pure mixin combination: all behavior comes from the generic
    hypersphere-subset Dirac base and the hyperspherical domain base.
    """
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import copy
2+
from collections.abc import Callable
3+
4+
from ..abstract_dirac_distribution import AbstractDiracDistribution
5+
6+
7+
class HyperhemisphereCartProdDiracDistribution(AbstractDiracDistribution):
    """Dirac (particle-based) distribution on a Cartesian product of hyperhemispheres.

    Each particle stacks ``n_hemispheres`` unit vectors of
    ``dim_hemisphere + 1`` Euclidean coordinates, so the last axis of ``d``
    has size ``(dim_hemisphere + 1) * n_hemispheres``.
    """

    def __init__(self, d, w, dim_hemisphere, n_hemispheres):
        """
        Initialize a Dirac distribution with given Dirac locations and weights.

        :param d: Dirac locations as a numpy array of shape
            (n_particles, (dim_hemisphere + 1) * n_hemispheres).
        :param w: Weights of Dirac locations as a numpy array. If not provided, defaults to uniform weights.
        :param dim_hemisphere: Dimension of each individual hyperhemisphere
            (its unit vectors carry dim_hemisphere + 1 coordinates).
        :param n_hemispheres: Number of hyperhemispheres in the product.
        """
        super().__init__(d, w)
        self.dim_hemisphere = dim_hemisphere
        self.n_hemispheres = n_hemispheres
        assert self.d.shape[-1] == (
            (1 + dim_hemisphere) * n_hemispheres
        ), "Dimension is not correct."
        # Manifold dimension of the product (excludes the embedding's extra
        # coordinate per hemisphere).
        self.dim = dim_hemisphere * n_hemispheres

    def apply_function_component_wise(
        self, f: Callable, f_supports_multiple: bool = True
    ):
        """
        Apply a function to each hemisphere's block of Dirac locations and
        return a new distribution.

        :param f: Function to apply. It receives the
            (n_particles, dim_hemisphere + 1) coordinate block of one
            hemisphere and must return an array of the same shape.
        :returns: A new distribution with the function applied to the locations.
        """
        assert f_supports_multiple, "Function must support multiple inputs."
        dist = copy.deepcopy(self)
        # Each hemisphere occupies dim_hemisphere + 1 consecutive COLUMNS of d
        # (see the shape assertion in __init__). The previous code sliced the
        # first rows and passed single particles to f, which mismatched the
        # layout used elsewhere (e.g. the Cart-prod particle filter).
        block = self.dim_hemisphere + 1
        for i in range(self.n_hemispheres):
            dist.d[:, i * block : (i + 1) * block] = f(  # noqa: E203
                self.d[:, i * block : (i + 1) * block]  # noqa: E203
            )
        return dist
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from collections.abc import Callable
2+
3+
import numpy as np
4+
5+
# pylint: disable=no-name-in-module,no-member
6+
from pyrecest.backend import ones # noqa: F821
7+
from pyrecest.distributions import AbstractHypersphericalDistribution
8+
from pyrecest.distributions.cart_prod.hyperhemisphere_cart_prod_dirac_distribution import (
9+
HyperhemisphereCartProdDiracDistribution,
10+
)
11+
from pyrecest.distributions.hypersphere_subset.abstract_hyperhemispherical_distribution import (
12+
AbstractHyperhemisphericalDistribution,
13+
)
14+
from pyrecest.distributions.hypersphere_subset.hyperhemispherical_watson_distribution import (
15+
HyperhemisphericalWatsonDistribution,
16+
)
17+
18+
from .abstract_particle_filter import AbstractParticleFilter
19+
from beartype import beartype
20+
21+
class HyperhemisphereCartProdParticleFilter(AbstractParticleFilter):
    """Particle filter whose state lives on a Cartesian product of hyperhemispheres.

    The particle matrix ``filter_state.d`` has one row per particle and
    ``(dim_hemisphere + 1) * n_hemispheres`` columns: the i-th hemisphere
    occupies the column block ``[i*(dim_hemisphere+1), (i+1)*(dim_hemisphere+1))``.
    """

    def __init__(
        self, n_particles: int | np.int32 | np.int64, dim_hemisphere, n_hemispheres
    ) -> None:
        """
        Constructor

        Parameters:
            n_particles (int > 0): Number of particles to use
            dim_hemisphere (int > 0): Dimension of each hyperhemisphere
            n_hemispheres (int > 0): Number of hyperhemispheres in the product
        """
        # Locations are left uninitialized (np.empty); weights start uniform.
        initial_filter_state = HyperhemisphereCartProdDiracDistribution(
            np.empty((n_particles, (dim_hemisphere + 1) * n_hemispheres)),
            ones(n_particles) / n_particles,
            dim_hemisphere,
            n_hemispheres,
        )
        AbstractParticleFilter.__init__(self, initial_filter_state=initial_filter_state)

    def set_state(self, new_state):
        """
        Sets the current system state

        Parameters:
            new_state (AbstractHyperhemisphericalDistribution): New state
        """
        assert isinstance(new_state, AbstractHyperhemisphericalDistribution)
        if not isinstance(new_state, HyperhemisphereCartProdDiracDistribution):
            # Approximate an arbitrary hyperhemispherical distribution by
            # sampling as many particles as the current state holds.
            n_particles = self.filter_state.d.shape[0]
            new_state = HyperhemisphereCartProdDiracDistribution(
                new_state.sample(n_particles),
                w=ones(n_particles) / n_particles,
                dim_hemisphere=self.filter_state.dim_hemisphere,
                n_hemispheres=self.filter_state.n_hemispheres,
            )
        self.filter_state = new_state

    @beartype
    def predict_nonlinear_each_part(
        self,
        f,
        # Limit noise distribution to ones that support set_mode
        noise_distribution: HyperhemisphericalWatsonDistribution | None = None,
        function_is_vectorized: bool = True,
        shift_instead_of_add: bool = True,
    ):
        """
        Predicts the next state for each hyperhemisphere.

        Parameters:
            f: Vectorized propagation function applied to each hemisphere's
                (n_particles, dim_hemisphere + 1) coordinate block.
            noise_distribution: Optional noise; must support set_mode and
                match the per-hemisphere dimension. If None, the prediction
                is deterministic.
            function_is_vectorized: Only True is supported.
            shift_instead_of_add: Only True is supported.
        """
        assert function_is_vectorized, "Only vectorized functions are supported"
        assert (
            noise_distribution is None
            or noise_distribution.dim == self.filter_state.dim_hemisphere
        ), "Noise dimension must match state dimension in Cartesian product"
        assert shift_instead_of_add, "Only shifting is supported"
        for i in range(self.filter_state.n_hemispheres):
            # Column block of the current hemisphere (dim_hemisphere + 1 wide).
            index_arr = range(
                i * (self.filter_state.dim_hemisphere + 1),
                (i + 1) * (self.filter_state.dim_hemisphere + 1),
            )
            # Consider only the part of the state of the current hemisphere
            curr_d = self.filter_state.d[:, index_arr]
            d_fun_applied = f(curr_d)
            if noise_distribution is None:
                # Fix: previously the transformed particles were discarded
                # when no noise was given; store the deterministic prediction.
                self.filter_state.d[:, index_arr] = d_fun_applied
            else:
                for j in range(self.filter_state.d.shape[0]):
                    # Center the noise at the transformed state.
                    # This will fail if set_mode is unavailable.
                    noise_distribution.set_mode(d_fun_applied[j])
                    # Draw one noisy sample around the transformed state.
                    self.filter_state.d[j, index_arr] = noise_distribution.sample(1)

    @property
    def filter_state(self):
        # Current HyperhemisphereCartProdDiracDistribution.
        return self._filter_state

    @filter_state.setter
    def filter_state(self, new_state):
        if isinstance(
            new_state,
            (
                AbstractHyperhemisphericalDistribution,
                AbstractHypersphericalDistribution,
            ),
        ):
            # A single-(hemi)sphere distribution: its dimension must match one
            # hemisphere; draw independent samples for every hemisphere slot.
            assert new_state.dim == self.filter_state.dim_hemisphere
            samples = new_state.sample(
                self._filter_state.d.shape[0] * self._filter_state.n_hemispheres
            )
            if isinstance(new_state, AbstractHypersphericalDistribution):
                # Mirror lower-half samples onto the upper hemisphere.
                lower = samples[:, -1] < 0
                samples[lower] = -samples[lower]
            self._filter_state.d = samples.reshape(self.filter_state.d.shape)
        else:
            # Delegate everything else (e.g. Dirac Cart-prod states) to the base.
            AbstractParticleFilter.filter_state.fset(self, new_state)
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import unittest
2+
3+
import pyrecest.backend # pylint: disable=no-name-in-module,no-member
4+
from pyrecest.backend import abs, array, sum # pylint: disable=no-name-in-module,no-member,redefined-builtin
5+
from pyrecest.distributions import VonMisesFisherDistribution
6+
from pyrecest.filters.hyperhemisphere_cart_prod_particle_filter import (
7+
HyperhemisphereCartProdParticleFilter,
8+
)
9+
10+
11+
class HyperHemisphereCartProdParticleFilterTest(unittest.TestCase):
    """Tests for HyperhemisphereCartProdParticleFilter."""

    # Shared fixture parameters: 1000 particles on a product of two
    # 3-dimensional hyperhemispheres (4 coordinates each, 8 columns total).
    N_PARTICLES = 1000
    DIM_HEMISPHERE = 3
    N_HEMISPHERES = 2

    def _make_filter(self):
        # Build the filter used by every test case.
        return HyperhemisphereCartProdParticleFilter(
            self.N_PARTICLES, self.DIM_HEMISPHERE, self.N_HEMISPHERES
        )

    def _assert_state_shape(self, pf):
        # Particle matrix must be (n_particles, (dim_hemisphere + 1) * n_hemispheres).
        self.assertEqual(
            pf.filter_state.d.shape,
            (self.N_PARTICLES, (self.DIM_HEMISPHERE + 1) * self.N_HEMISPHERES),
        )

    def test_init(self):
        pf = self._make_filter()
        self._assert_state_shape(pf)

    @unittest.skipIf(
        pyrecest.backend.__name__ in ("pyrecest.jax", "pyrecest.pytorch"),  # pylint: disable=no-name-in-module,no-member
        reason="Backend not supported",
    )
    def test_set_state(self):
        pf = self._make_filter()
        self._assert_state_shape(pf)
        # Setting a VMF per-hemisphere distribution resamples all particles.
        pf.filter_state = VonMisesFisherDistribution(array([0.0, 0.0, 0.0, 1.0]), 1.0)

    @unittest.skipIf(
        pyrecest.backend.__name__ in ("pyrecest.jax", "pyrecest.pytorch"),  # pylint: disable=no-name-in-module,no-member
        reason="Backend not supported",
    )
    def test_predict(self):
        pf = self._make_filter()
        pf.filter_state = VonMisesFisherDistribution(array([0.0, 0.0, 0.0, 1.0]), 1.0)

        noise_distribution = VonMisesFisherDistribution(
            array([0.0, 0.0, 0.0, 1.0]), 1.0
        )

        def identity_function(x):
            return x

        # Identity propagation plus VMF noise must run without error.
        pf.predict_nonlinear_each_part(identity_function, noise_distribution)

    @unittest.skipIf(
        pyrecest.backend.__name__ in ("pyrecest.jax", "pyrecest.pytorch"),  # pylint: disable=no-name-in-module,no-member
        reason="Backend not supported",
    )
    def test_update(self):
        pf = self._make_filter()
        pf.filter_state = VonMisesFisherDistribution(array([0.0, 0.0, 0.0, 1.0]), 1.0)

        def likelihood_function(x):
            return abs(sum(x, axis=1))  # noqa: A001

        # Likelihood-based reweighting/resampling must run without error.
        pf.update_nonlinear_using_likelihood(likelihood_function)
80+
81+
82+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)