Skip to content

Commit 5d4a087

Browse files
committed
Added CircularGridDistribution
1 parent 8dbcb1c commit 5d4a087

2 files changed

Lines changed: 261 additions & 0 deletions

File tree

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
from .abstract_circular_distribution import AbstractCircularDistribution
2+
from .circular_fourier_distribution import CircularFourierDistribution
3+
4+
from pyrecest.backend import array, pi, where, sin, linspace, ceil, floor, arange, sqrt, sum, maximum, mod, round, isclose, \
5+
import numpy as np
6+
import matplotlib.pyplot as plt
7+
import warnings
8+
9+
class CircularGridDistribution(AbstractCircularDistribution, HypertoroidalGridDistribution):
10+
"""
11+
Density representation using function values on a grid with Fourier interpolation.
12+
"""
13+
14+
def __init__(self, gridValues, enforcePdfNonnegative=True):
15+
gridValues = array(gridValues)
16+
# Check if gridValues is already a distribution (in which case use fromDistribution)
17+
if isinstance(gridValues, AbstractCircularDistribution):
18+
raise ValueError("You gave a distribution as the first argument. "
19+
"To convert distributions to a distribution in grid representation, use fromDistribution.")
20+
# Call parent constructor (HypertoroidalGridDistribution)
21+
HypertoroidalGridDistribution.__init__(self, None, gridValues, enforcePdfNonnegative)
22+
self.gridValues = gridValues
23+
self.enforcePdfNonnegative = enforcePdfNonnegative
24+
25+
def pdf(self, xs, useSinc=False, sincRepetitions=5):
26+
xs = array(xs)
27+
if useSinc:
28+
if sincRepetitions % 2 != 1:
29+
raise ValueError("sincRepetitions must be an odd integer.")
30+
N = len(self.gridValues)
31+
step_size = 2 * pi / N
32+
33+
# Define MATLAB-style sinc: sinc(x) = sin(x)/x with sinc(0)=1
34+
def matlab_sinc(x):
35+
return where(x == 0, 1.0, sin(x) / x)
36+
37+
# Create the range vector: from -floor(sincRepetitions/2)*N to ceil(sincRepetitions/2)*N - 1
38+
lower = int(floor(sincRepetitions / 2) * N)
39+
upper = int(ceil(sincRepetitions / 2) * N)
40+
r = arange(-lower, upper)
41+
42+
# Compute the sinc values. (xs/step_size) becomes a column vector.
43+
sincVals = matlab_sinc((xs / step_size)[:, None] - r[None, :])
44+
# Tile the grid values; note that MATLAB’s repmat(gridValues', [1, sincRepetitions])
45+
if self.enforcePdfNonnegative:
46+
coeffs = np.tile(sqrt(self.gridValues), sincRepetitions)
47+
p = sum(coeffs * sincVals, axis=1) ** 2
48+
else:
49+
coeffs = np.tile(self.gridValues, sincRepetitions)
50+
p = sum(coeffs * sincVals, axis=1)
51+
return p
52+
else:
53+
N = len(self.gridValues)
54+
noCoeffs = N
55+
if N % 2 == 0:
56+
noCoeffs += 1 # extra coefficient as in the MATLAB code
57+
transform = 'sqrt' if self.enforcePdfNonnegative else 'identity'
58+
fd = CircularFourierDistribution.fromFunctionValues(self.gridValues, noCoeffs, transform)
59+
return fd.pdf(xs)
60+
61+
def pdfOnGrid(self, noOfDesiredGridpoints):
62+
N = len(self.gridValues)
63+
xGrid = np.linspace(0, 2*np.pi, N, endpoint=False)
64+
step = N / noOfDesiredGridpoints
65+
if not float(step).is_integer():
66+
raise ValueError("Number of function values must be a multiple of noOfDesiredGridpoints")
67+
step = int(step)
68+
vals = self.gridValues[::step]
69+
return vals, xGrid
70+
71+
def trigonometricMoment(self, n):
72+
N = len(self.gridValues)
73+
noCoeffs = N
74+
if N % 2 == 0:
75+
noCoeffs += 1
76+
# In MATLAB a warning is suppressed here; in Python you might use warnings.filterwarnings.
77+
fd = CircularFourierDistribution.fromFunctionValues(self.gridValues, noCoeffs, 'identity')
78+
return fd.trigonometricMoment(n)
79+
80+
def plot(self, *args, **kwargs):
81+
N = len(self.gridValues)
82+
gridPoints = np.linspace(0, 2*np.pi, N, endpoint=False)
83+
# Optionally, call a parent plot if available.
84+
super().plot(*args, **kwargs)
85+
# Then also plot the grid points as markers.
86+
plt.plot(gridPoints, self.gridValues, 'x')
87+
plt.xlabel("Angle")
88+
plt.ylabel("Grid Values")
89+
plt.show()
90+
# Returning the list of line handles is optional in Python.
91+
return
92+
93+
def value(self, xa):
94+
"""
95+
Evaluate the density at points when interpreted as a probability mass function.
96+
This implementation assumes that the grid points are used as indicators.
97+
"""
98+
xa = np.array(xa)
99+
grid = self.getGrid()
100+
# Find exact matches (using np.isclose for floating-point comparisons)
101+
result = np.zeros_like(xa)
102+
for i, angle in enumerate(xa):
103+
match = np.where(np.isclose(grid, angle))[0]
104+
result[i] = self.gridValues[match[0]] if match.size > 0 else 0
105+
return result
106+
107+
def getGrid(self):
108+
N = len(self.gridValues)
109+
return np.linspace(0, 2*np.pi, N, endpoint=False)
110+
111+
def getGridPoint(self, indices=None):
112+
"""
113+
If indices is None, returns the entire grid.
114+
Otherwise, returns grid points corresponding to (indices - 1)*2*pi/N.
115+
(MATLAB indices are 1-based; here we mimic that conversion.)
116+
"""
117+
N = len(self.gridValues)
118+
if indices is None:
119+
return self.getGrid()
120+
else:
121+
indices = np.array(indices)
122+
return (indices - 1) * (2 * np.pi / N)
123+
124+
def convolve(self, f2):
125+
if self.enforcePdfNonnegative != f2.enforcePdfNonnegative:
126+
raise ValueError("Mismatch in enforcePdfNonnegative between the two distributions.")
127+
if len(self.gridValues) != len(f2.gridValues):
128+
raise ValueError("Grid sizes must be identical for convolution.")
129+
N = len(self.gridValues)
130+
# Circular convolution using FFTs
131+
fft1 = np.fft.fft(self.gridValues)
132+
fft2 = np.fft.fft(f2.gridValues)
133+
convResult = np.real(np.fft.ifft(fft1 * fft2)) * (2 * np.pi / N)
134+
convResult[convResult < 0] = 0 # Remove small negative values due to numerical error
135+
return CircularGridDistribution(convResult, self.enforcePdfNonnegative)
136+
137+
def truncate(self, noOfGridpoints):
138+
N = len(self.gridValues)
139+
if noOfGridpoints - 1 <= 0:
140+
raise ValueError("Number of coefficients must be an integer greater than zero")
141+
step = N / noOfGridpoints
142+
if float(step).is_integer():
143+
new_vals = self.gridValues[::int(step)]
144+
return CircularGridDistribution(new_vals, self.enforcePdfNonnegative)
145+
elif N < noOfGridpoints:
146+
warnings.warn("Less coefficients than desired, interpolate using Fourier while ensuring nonnegativity.")
147+
self.enforcePdfNonnegative = True
148+
return CircularGridDistribution.fromDistribution(self, noOfGridpoints, self.enforcePdfNonnegative)
149+
else:
150+
warnings.warn("Cannot downsample directly. Transforming to Fourier to interpolate.")
151+
return CircularGridDistribution.fromDistribution(self, noOfGridpoints, self.enforcePdfNonnegative)
152+
153+
def normalize(self, tol=1e-2, warnUnnorm=True):
154+
# Call normalization from the HypertoroidalGridDistribution
155+
return super().normalize(tol=tol, warnUnnorm=warnUnnorm)
156+
157+
def shift(self, angle):
158+
if not np.isscalar(angle):
159+
raise ValueError("Angle must be a scalar.")
160+
fd = CircularFourierDistribution.fromFunctionValues(self.gridValues, len(self.gridValues), 'identity')
161+
fd = fd.shift(angle)
162+
new_vals, _ = fd.pdfOnGrid(len(self.gridValues))
163+
return CircularGridDistribution(new_vals, self.enforcePdfNonnegative)
164+
165+
def getClosestPoint(self, xa):
166+
xa = np.array(xa)
167+
N = len(self.gridValues)
168+
# MATLAB: indices = mod(round(xa/(2*pi/N)), N) + 1 (1-indexed)
169+
# In Python, we compute zero-indexed indices.
170+
indices = (np.round(xa / (2 * np.pi / N)) % N).astype(int)
171+
points = indices * (2 * np.pi / N)
172+
return points, indices
173+
174+
# --- Static methods ---
175+
@staticmethod
176+
def fromDistribution(distribution, noOfGridpoints, enforcePdfNonnegative=True):
177+
if isinstance(distribution, CicularFourierDistribution):
178+
with warnings.catch_warnings():
179+
warnings.simplefilter("ignore")
180+
fdToConv = distribution.truncate(noOfGridpoints)
181+
# The MATLAB code uses ifftshift and then ifft; here we mimic that:
182+
c_shifted = np.fft.ifftshift(fdToConv.c)
183+
valsOnGrid = np.real(np.fft.ifft(c_shifted)) * (len(fdToConv.a) + len(fdToConv.b))
184+
if fdToConv.transformation == 'identity':
185+
if np.any(valsOnGrid < 0):
186+
warnings.warn("This is an inaccurate transformation because negative values occurred. They are increased to 0.")
187+
valsOnGrid = np.maximum(valsOnGrid, 0)
188+
if enforcePdfNonnegative:
189+
warnings.warn("Interpolation differences may lead to inaccuracies; consider setting enforcePdfNonnegative to False.")
190+
elif fdToConv.transformation == 'sqrt':
191+
if np.any(valsOnGrid < 0) and enforcePdfNonnegative:
192+
warnings.warn("Negative values occurred in the sqrt transformation. Consider converting to identity first.")
193+
elif (not enforcePdfNonnegative) and noOfGridpoints < (2 * (len(distribution.a) + len(distribution.b)) - 1):
194+
warnings.warn("Interpolation differences may lead to inaccuracies with too few coefficients.")
195+
valsOnGrid = valsOnGrid ** 2
196+
else:
197+
raise ValueError("Transformation unsupported")
198+
return CircularGridDistribution(valsOnGrid, enforcePdfNonnegative)
199+
else:
200+
return CircularGridDistribution.fromFunction(lambda x: distribution.pdf(x),
201+
noOfGridpoints,
202+
enforcePdfNonnegative)
203+
204+
@staticmethod
205+
def fromFunction(fun, noOfCoefficients, enforcePdfNonnegative=True):
206+
gridPoints = np.linspace(0, 2*np.pi, noOfCoefficients, endpoint=False)
207+
gridValues = np.array(fun(gridPoints))
208+
return CircularGridDistribution(gridValues, enforcePdfNonnegative)
209+
210+
@staticmethod
211+
def fromFunctionValues(fvals, noOfGridpoints, enforcePdfNonnegative=True):
212+
fvals = np.array(fvals)
213+
step = len(fvals) / noOfGridpoints
214+
if not float(step).is_integer():
215+
raise ValueError("Number of function values has to be a multiple of noOfGridpoints")
216+
step = int(step)
217+
fvals_reduced = fvals[::step]
218+
return CircularGridDistribution(fvals_reduced, enforcePdfNonnegative)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import unittest
2+
import numpy as np
3+
from pyrecest.distributions.circle.circular_grid_distribution import CircularGridDistribution
4+
from pyrecest.distributions import VonMisesDistribution, WrappedNormalDistribution
5+
6+
class CircularGridDistributionTest(unittest.TestCase):
7+
8+
@staticmethod
9+
def _test_grid_conversion(dist, coeffs, enforceNonnegative, tolerance):
10+
figd = CircularGridDistribution.from_distribution(dist, coeffs, enforce_pdf_nonnegative=enforceNonnegative)
11+
# Test grid values
12+
xvals = np.linspace(0, 2*np.pi, coeffs, endpoint=False)
13+
np.testing.assert_allclose(figd.pdf(xvals), dist.pdf(xvals), atol=tolerance)
14+
# Test approximation of pdf
15+
xvals = np.arange(-2 * np.pi, 3 * np.pi, 0.01)
16+
np.testing.assert_allclose(figd.pdf(xvals), dist.pdf(xvals), atol=tolerance)
17+
18+
def test_VMToGridId(self):
19+
mu = 0.4
20+
for kappa in np.arange(.1, 2.1, .1):
21+
dist = VonMisesDistribution(mu, kappa)
22+
self._test_grid_conversion(dist, 101, False, 1E-8)
23+
24+
def test_VMToGridSqrt(self):
25+
mu = 0.5
26+
for kappa in np.arange(.1, 2.1, .1):
27+
dist = VonMisesDistribution(mu, kappa)
28+
self._test_grid_conversion(dist, 101, True, 1E-8)
29+
30+
def test_WNToGridId(self):
31+
mu = 0.8
32+
for sigma in np.arange(.2, 2.1, .1):
33+
dist = WrappedNormalDistribution(mu, sigma)
34+
self._test_grid_conversion(dist, 101, False, 1E-8)
35+
36+
def test_WNToGridSqrt(self):
37+
mu = 0.9
38+
for sigma in np.arange(.2, 2.1, .1):
39+
dist = WrappedNormalDistribution(mu, sigma)
40+
self._test_grid_conversion(dist, 101, True, 1E-8)
41+
42+
if __name__ == "__main__":
43+
unittest.main()

0 commit comments

Comments
 (0)