|
| 1 | +from collections.abc import Callable |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +# pylint: disable=no-name-in-module,no-member |
| 6 | +from pyrecest.backend import ones # noqa: F821 |
| 7 | +from pyrecest.distributions import AbstractHypersphericalDistribution |
| 8 | +from pyrecest.distributions.cart_prod.hyperhemisphere_cart_prod_dirac_distribution import ( |
| 9 | + HyperhemisphereCartProdDiracDistribution, |
| 10 | +) |
| 11 | +from pyrecest.distributions.hypersphere_subset.abstract_hyperhemispherical_distribution import ( |
| 12 | + AbstractHyperhemisphericalDistribution, |
| 13 | +) |
| 14 | +from pyrecest.distributions.hypersphere_subset.hyperhemispherical_watson_distribution import ( |
| 15 | + HyperhemisphericalWatsonDistribution, |
| 16 | +) |
| 17 | + |
| 18 | +from .abstract_particle_filter import AbstractParticleFilter |
| 19 | +from beartype import beartype |
| 20 | + |
class HyperhemisphereCartProdParticleFilter(AbstractParticleFilter):
    """Particle filter whose state lives on a Cartesian product of hyperhemispheres.

    Each particle is a row of length (dim_hemisphere + 1) * n_hemispheres:
    the concatenation of one (dim_hemisphere + 1)-dimensional unit vector per
    hemisphere.
    """

    def __init__(
        self, n_particles: int | np.int32 | np.int64, dim_hemisphere, n_hemispheres
    ) -> None:
        """
        Constructor

        Parameters:
            n_particles (int > 0): Number of particles to use
            dim_hemisphere (int > 0): Dimension of each individual hyperhemisphere
            n_hemispheres (int > 0): Number of hyperhemispheres in the Cartesian product
        """
        # Particles start uninitialized (np.empty) with uniform weights; callers
        # are expected to set a proper state via set_state / filter_state.
        initial_filter_state = HyperhemisphereCartProdDiracDistribution(
            np.empty((n_particles, (dim_hemisphere + 1) * n_hemispheres)),
            ones(n_particles) / n_particles,
            dim_hemisphere,
            n_hemispheres,
        )
        AbstractParticleFilter.__init__(self, initial_filter_state=initial_filter_state)

    def set_state(self, new_state):
        """
        Sets the current system state

        Parameters:
            new_state (AbstractHyperhemisphericalDistribution): New state; if it
                is not already a HyperhemisphereCartProdDiracDistribution, it is
                approximated by sampling as many particles as are currently used.
        """
        assert isinstance(new_state, AbstractHyperhemisphericalDistribution)
        if not isinstance(new_state, HyperhemisphereCartProdDiracDistribution):
            # Convert arbitrary hyperhemispherical distributions into a Dirac
            # mixture with equal weights, preserving the particle count.
            n_particles = self.filter_state.d.shape[0]
            new_state = HyperhemisphereCartProdDiracDistribution(
                new_state.sample(n_particles),
                w=ones(n_particles) / n_particles,
                dim_hemisphere=self.filter_state.dim_hemisphere,
                n_hemispheres=self.filter_state.n_hemispheres,
            )
        self.filter_state = new_state

    @beartype
    def predict_nonlinear_each_part(
        self,
        f: Callable,
        # Limit noise distribution to ones that support set_mode
        noise_distribution: HyperhemisphericalWatsonDistribution | None = None,
        function_is_vectorized: bool = True,
        shift_instead_of_add: bool = True,
    ):
        """
        Predicts the next state by applying f to each hyperhemisphere's part of
        the state separately, optionally perturbing the result with noise.

        Parameters:
            f (Callable): Vectorized propagation function mapping an
                (n_particles, dim_hemisphere + 1) array to an array of the same shape.
            noise_distribution: Optional noise; its mode is shifted to each
                transformed particle before sampling (hence set_mode is required).
            function_is_vectorized (bool): Must be True (only vectorized f supported).
            shift_instead_of_add (bool): Must be True (only mode-shifting supported).
        """
        assert function_is_vectorized, "Only vectorized functions are supported"
        assert (
            noise_distribution is None or noise_distribution.dim == self.filter_state.dim_hemisphere
        ), "Noise dimension must match state dimension in Cartesian product"
        assert shift_instead_of_add, "Only shifting is supported"
        for i in range(self.filter_state.n_hemispheres):
            # Columns belonging to the i-th hyperhemisphere
            index_arr = range(
                i * (self.filter_state.dim_hemisphere + 1),
                (i + 1) * (self.filter_state.dim_hemisphere + 1),
            )
            # Consider only part of the state of the current hemisphere
            curr_d = self.filter_state.d[:, index_arr]
            d_fun_applied = f(curr_d)
            if noise_distribution is None:
                # BUGFIX: previously the transformed particles were silently
                # discarded when no noise was given; write them back.
                self.filter_state.d[:, index_arr] = d_fun_applied
            else:
                for j in range(self.filter_state.d.shape[0]):
                    # Center the noise at the transformed particle
                    # (this will fail if set_mode is unavailable)
                    noise_distribution.set_mode(d_fun_applied[j])
                    # Sample one noisy particle around the transformed state
                    self.filter_state.d[j, index_arr] = noise_distribution.sample(1)

    @property
    def filter_state(self):
        # Current HyperhemisphereCartProdDiracDistribution holding the particles.
        return self._filter_state

    @filter_state.setter
    def filter_state(self, new_state):
        if isinstance(
            new_state,
            (
                AbstractHyperhemisphericalDistribution,
                AbstractHypersphericalDistribution,
            ),
        ):
            # A distribution on a single (hemi)sphere: draw one sample per
            # particle per hemisphere and tile them into the product state.
            assert new_state.dim == self.filter_state.dim_hemisphere
            samples = new_state.sample(
                self._filter_state.d.shape[0] * self._filter_state.n_hemispheres
            )
            if isinstance(new_state, AbstractHypersphericalDistribution):
                # Mirror samples from the lower hemisphere onto the upper one
                samples[samples[:, -1] < 0] = -samples[samples[:, -1] < 0]
            self._filter_state.d = samples.reshape(self.filter_state.d.shape)
        else:
            # Delegate all other state types to the base-class setter.
            AbstractParticleFilter.filter_state.fset(self, new_state)
0 commit comments