1+ from .abstract_circular_distribution import AbstractCircularDistribution
2+ from .circular_fourier_distribution import CircularFourierDistribution
3+
4+ from pyrecest .backend import array , pi , where , sin , linspace , ceil , floor , arange , sqrt , sum , maximum , mod , round , isclose , \
5+ import numpy as np
6+ import matplotlib .pyplot as plt
7+ import warnings
8+
9+ class CircularGridDistribution (AbstractCircularDistribution , HypertoroidalGridDistribution ):
10+ """
11+ Density representation using function values on a grid with Fourier interpolation.
12+ """
13+
14+ def __init__ (self , gridValues , enforcePdfNonnegative = True ):
15+ gridValues = array (gridValues )
16+ # Check if gridValues is already a distribution (in which case use fromDistribution)
17+ if isinstance (gridValues , AbstractCircularDistribution ):
18+ raise ValueError ("You gave a distribution as the first argument. "
19+ "To convert distributions to a distribution in grid representation, use fromDistribution." )
20+ # Call parent constructor (HypertoroidalGridDistribution)
21+ HypertoroidalGridDistribution .__init__ (self , None , gridValues , enforcePdfNonnegative )
22+ self .gridValues = gridValues
23+ self .enforcePdfNonnegative = enforcePdfNonnegative
24+
25+ def pdf (self , xs , useSinc = False , sincRepetitions = 5 ):
26+ xs = array (xs )
27+ if useSinc :
28+ if sincRepetitions % 2 != 1 :
29+ raise ValueError ("sincRepetitions must be an odd integer." )
30+ N = len (self .gridValues )
31+ step_size = 2 * pi / N
32+
33+ # Define MATLAB-style sinc: sinc(x) = sin(x)/x with sinc(0)=1
34+ def matlab_sinc (x ):
35+ return where (x == 0 , 1.0 , sin (x ) / x )
36+
37+ # Create the range vector: from -floor(sincRepetitions/2)*N to ceil(sincRepetitions/2)*N - 1
38+ lower = int (floor (sincRepetitions / 2 ) * N )
39+ upper = int (ceil (sincRepetitions / 2 ) * N )
40+ r = arange (- lower , upper )
41+
42+ # Compute the sinc values. (xs/step_size) becomes a column vector.
43+ sincVals = matlab_sinc ((xs / step_size )[:, None ] - r [None , :])
44+ # Tile the grid values; note that MATLAB’s repmat(gridValues', [1, sincRepetitions])
45+ if self .enforcePdfNonnegative :
46+ coeffs = np .tile (sqrt (self .gridValues ), sincRepetitions )
47+ p = sum (coeffs * sincVals , axis = 1 ) ** 2
48+ else :
49+ coeffs = np .tile (self .gridValues , sincRepetitions )
50+ p = sum (coeffs * sincVals , axis = 1 )
51+ return p
52+ else :
53+ N = len (self .gridValues )
54+ noCoeffs = N
55+ if N % 2 == 0 :
56+ noCoeffs += 1 # extra coefficient as in the MATLAB code
57+ transform = 'sqrt' if self .enforcePdfNonnegative else 'identity'
58+ fd = CircularFourierDistribution .fromFunctionValues (self .gridValues , noCoeffs , transform )
59+ return fd .pdf (xs )
60+
61+ def pdfOnGrid (self , noOfDesiredGridpoints ):
62+ N = len (self .gridValues )
63+ xGrid = np .linspace (0 , 2 * np .pi , N , endpoint = False )
64+ step = N / noOfDesiredGridpoints
65+ if not float (step ).is_integer ():
66+ raise ValueError ("Number of function values must be a multiple of noOfDesiredGridpoints" )
67+ step = int (step )
68+ vals = self .gridValues [::step ]
69+ return vals , xGrid
70+
71+ def trigonometricMoment (self , n ):
72+ N = len (self .gridValues )
73+ noCoeffs = N
74+ if N % 2 == 0 :
75+ noCoeffs += 1
76+ # In MATLAB a warning is suppressed here; in Python you might use warnings.filterwarnings.
77+ fd = CircularFourierDistribution .fromFunctionValues (self .gridValues , noCoeffs , 'identity' )
78+ return fd .trigonometricMoment (n )
79+
80+ def plot (self , * args , ** kwargs ):
81+ N = len (self .gridValues )
82+ gridPoints = np .linspace (0 , 2 * np .pi , N , endpoint = False )
83+ # Optionally, call a parent plot if available.
84+ super ().plot (* args , ** kwargs )
85+ # Then also plot the grid points as markers.
86+ plt .plot (gridPoints , self .gridValues , 'x' )
87+ plt .xlabel ("Angle" )
88+ plt .ylabel ("Grid Values" )
89+ plt .show ()
90+ # Returning the list of line handles is optional in Python.
91+ return
92+
93+ def value (self , xa ):
94+ """
95+ Evaluate the density at points when interpreted as a probability mass function.
96+ This implementation assumes that the grid points are used as indicators.
97+ """
98+ xa = np .array (xa )
99+ grid = self .getGrid ()
100+ # Find exact matches (using np.isclose for floating-point comparisons)
101+ result = np .zeros_like (xa )
102+ for i , angle in enumerate (xa ):
103+ match = np .where (np .isclose (grid , angle ))[0 ]
104+ result [i ] = self .gridValues [match [0 ]] if match .size > 0 else 0
105+ return result
106+
107+ def getGrid (self ):
108+ N = len (self .gridValues )
109+ return np .linspace (0 , 2 * np .pi , N , endpoint = False )
110+
111+ def getGridPoint (self , indices = None ):
112+ """
113+ If indices is None, returns the entire grid.
114+ Otherwise, returns grid points corresponding to (indices - 1)*2*pi/N.
115+ (MATLAB indices are 1-based; here we mimic that conversion.)
116+ """
117+ N = len (self .gridValues )
118+ if indices is None :
119+ return self .getGrid ()
120+ else :
121+ indices = np .array (indices )
122+ return (indices - 1 ) * (2 * np .pi / N )
123+
124+ def convolve (self , f2 ):
125+ if self .enforcePdfNonnegative != f2 .enforcePdfNonnegative :
126+ raise ValueError ("Mismatch in enforcePdfNonnegative between the two distributions." )
127+ if len (self .gridValues ) != len (f2 .gridValues ):
128+ raise ValueError ("Grid sizes must be identical for convolution." )
129+ N = len (self .gridValues )
130+ # Circular convolution using FFTs
131+ fft1 = np .fft .fft (self .gridValues )
132+ fft2 = np .fft .fft (f2 .gridValues )
133+ convResult = np .real (np .fft .ifft (fft1 * fft2 )) * (2 * np .pi / N )
134+ convResult [convResult < 0 ] = 0 # Remove small negative values due to numerical error
135+ return CircularGridDistribution (convResult , self .enforcePdfNonnegative )
136+
137+ def truncate (self , noOfGridpoints ):
138+ N = len (self .gridValues )
139+ if noOfGridpoints - 1 <= 0 :
140+ raise ValueError ("Number of coefficients must be an integer greater than zero" )
141+ step = N / noOfGridpoints
142+ if float (step ).is_integer ():
143+ new_vals = self .gridValues [::int (step )]
144+ return CircularGridDistribution (new_vals , self .enforcePdfNonnegative )
145+ elif N < noOfGridpoints :
146+ warnings .warn ("Less coefficients than desired, interpolate using Fourier while ensuring nonnegativity." )
147+ self .enforcePdfNonnegative = True
148+ return CircularGridDistribution .fromDistribution (self , noOfGridpoints , self .enforcePdfNonnegative )
149+ else :
150+ warnings .warn ("Cannot downsample directly. Transforming to Fourier to interpolate." )
151+ return CircularGridDistribution .fromDistribution (self , noOfGridpoints , self .enforcePdfNonnegative )
152+
153+ def normalize (self , tol = 1e-2 , warnUnnorm = True ):
154+ # Call normalization from the HypertoroidalGridDistribution
155+ return super ().normalize (tol = tol , warnUnnorm = warnUnnorm )
156+
157+ def shift (self , angle ):
158+ if not np .isscalar (angle ):
159+ raise ValueError ("Angle must be a scalar." )
160+ fd = CircularFourierDistribution .fromFunctionValues (self .gridValues , len (self .gridValues ), 'identity' )
161+ fd = fd .shift (angle )
162+ new_vals , _ = fd .pdfOnGrid (len (self .gridValues ))
163+ return CircularGridDistribution (new_vals , self .enforcePdfNonnegative )
164+
165+ def getClosestPoint (self , xa ):
166+ xa = np .array (xa )
167+ N = len (self .gridValues )
168+ # MATLAB: indices = mod(round(xa/(2*pi/N)), N) + 1 (1-indexed)
169+ # In Python, we compute zero-indexed indices.
170+ indices = (np .round (xa / (2 * np .pi / N )) % N ).astype (int )
171+ points = indices * (2 * np .pi / N )
172+ return points , indices
173+
174+ # --- Static methods ---
175+ @staticmethod
176+ def fromDistribution (distribution , noOfGridpoints , enforcePdfNonnegative = True ):
177+ if isinstance (distribution , CicularFourierDistribution ):
178+ with warnings .catch_warnings ():
179+ warnings .simplefilter ("ignore" )
180+ fdToConv = distribution .truncate (noOfGridpoints )
181+ # The MATLAB code uses ifftshift and then ifft; here we mimic that:
182+ c_shifted = np .fft .ifftshift (fdToConv .c )
183+ valsOnGrid = np .real (np .fft .ifft (c_shifted )) * (len (fdToConv .a ) + len (fdToConv .b ))
184+ if fdToConv .transformation == 'identity' :
185+ if np .any (valsOnGrid < 0 ):
186+ warnings .warn ("This is an inaccurate transformation because negative values occurred. They are increased to 0." )
187+ valsOnGrid = np .maximum (valsOnGrid , 0 )
188+ if enforcePdfNonnegative :
189+ warnings .warn ("Interpolation differences may lead to inaccuracies; consider setting enforcePdfNonnegative to False." )
190+ elif fdToConv .transformation == 'sqrt' :
191+ if np .any (valsOnGrid < 0 ) and enforcePdfNonnegative :
192+ warnings .warn ("Negative values occurred in the sqrt transformation. Consider converting to identity first." )
193+ elif (not enforcePdfNonnegative ) and noOfGridpoints < (2 * (len (distribution .a ) + len (distribution .b )) - 1 ):
194+ warnings .warn ("Interpolation differences may lead to inaccuracies with too few coefficients." )
195+ valsOnGrid = valsOnGrid ** 2
196+ else :
197+ raise ValueError ("Transformation unsupported" )
198+ return CircularGridDistribution (valsOnGrid , enforcePdfNonnegative )
199+ else :
200+ return CircularGridDistribution .fromFunction (lambda x : distribution .pdf (x ),
201+ noOfGridpoints ,
202+ enforcePdfNonnegative )
203+
204+ @staticmethod
205+ def fromFunction (fun , noOfCoefficients , enforcePdfNonnegative = True ):
206+ gridPoints = np .linspace (0 , 2 * np .pi , noOfCoefficients , endpoint = False )
207+ gridValues = np .array (fun (gridPoints ))
208+ return CircularGridDistribution (gridValues , enforcePdfNonnegative )
209+
210+ @staticmethod
211+ def fromFunctionValues (fvals , noOfGridpoints , enforcePdfNonnegative = True ):
212+ fvals = np .array (fvals )
213+ step = len (fvals ) / noOfGridpoints
214+ if not float (step ).is_integer ():
215+ raise ValueError ("Number of function values has to be a multiple of noOfGridpoints" )
216+ step = int (step )
217+ fvals_reduced = fvals [::step ]
218+ return CircularGridDistribution (fvals_reduced , enforcePdfNonnegative )
0 commit comments