2828
2929from __future__ import annotations
3030
31+ from typing import TypeVar
32+
3133
3234__copyright__ = """
3335Copyright (C) 2024 Hirish Chandrasekaran
5456THE SOFTWARE.
5557"""
5658import math
59+
5760import numpy as np
5861import sympy as sp
62+
5963from pytools .obj_array import make_obj_array
64+
6065from sumpy .expansion .diff_op import (
61- DerivativeIdentifier , make_identity_diff_op , laplacian , LinearPDESystemOperator )
66+ DerivativeIdentifier ,
67+ LinearPDESystemOperator ,
68+ laplacian ,
69+ make_identity_diff_op ,
70+ )
6271
6372
6473# similar to make_sym_vector in sumpy.symbolic, but returns an object array
@@ -166,8 +175,11 @@ def ode_in_r_to_x(ode_in_r: sp.Expr, var: np.ndarray, ode_order: int) -> sp.Expr
166175 return ode_in_x
167176
168177
178+ ODECoefficients = list [list [sp .Expr ]]
179+
180+
169181def ode_in_x_to_coeff_array (poly : sp .Poly , ode_order : int ,
170- var : np .ndarray ) -> list :
182+ var : np .ndarray ) -> ODECoefficients :
171183 r"""
172184 Organizes the coefficients of an ODE in the :math:`x_0` variable into a 2D array.
173185
@@ -184,11 +196,26 @@ def ode_in_x_to_coeff_array(poly: sp.Poly, ode_order: int,
184196 :math:`x_0`, so that, in terms of the above form, coeffs is
185197 :math:`[[b_{00}, b_{01}, ...], [b_{10}, b_{11}, ...], ...]`
186198 """
187- def kronecker (i , n = ode_order + 1 ):
188- return tuple (1 if i == j else 0 for j in range (n ))
199+ return [
200+ # recast ODE coefficient obtained below as polynomial in x0
201+ sp .Poly (
202+ # get coefficient of deriv_ind'th derivative
203+ poly .coeff_monomial (poly .gens [deriv_ind ]),
204+
205+ var [0 ])
206+ # get poly coefficients in /ascending/ order
207+ .all_coeffs ()[::- 1 ]
208+ for deriv_ind in range (ode_order + 1 )]
209+
189210
190- return [sp .Poly (poly .coeff_monomial (kronecker (deriv_ind )),
191- var [0 ]).all_coeffs ()[::- 1 ] for deriv_ind in range (ode_order + 1 )]
211+ NumberT = TypeVar ("NumberT" , int , float , complex )
212+
213+
214+ def _falling_factorial (arg : NumberT , num_terms : int ) -> NumberT :
215+ result = 1
216+ for i in range (num_terms ):
217+ result = result * (arg - i )
218+ return result
192219
193220
194221def _auto_product_rule_single_term (p : int , m : int , var : np .ndarray ) -> sp .Expr :
@@ -198,21 +225,15 @@ def _auto_product_rule_single_term(p: int, m: int, var: np.ndarray) -> sp.Expr:
198225 variable.
199226 We let :math:`s(i)` represent the ith order derivative of f when
200227 we output the final result.
201- :arg p: see description
202- :arg m: see description
203228 :arg var: array of sympy variables :math:`[x_0, x_1, \dots]`
204229 """
205230 n = sp .symbols ("n" )
206231 s = sp .Function ("s" )
207- result = 0
208- for i in range (p + 1 ):
209- temp = 1
210- for j in range (i ):
211- temp *= (n - j )
212- # pylint: disable=not-callable
213- temp *= math .comb (p , i ) * s (n - i + m ) * var [0 ]** (p - i )
214- result += temp
215- return result
232+ return sum (
233+ _falling_factorial (n , i )
234+ * math .comb (p , i ) * s (n - i + m ) * var [0 ]** (p - i )
235+ for i in range (p + 1 )
236+ )
216237
217238
218239def recurrence_from_coeff_array (coeffs : list , var : np .ndarray ) -> sp .Expr :
@@ -225,8 +246,8 @@ def recurrence_from_coeff_array(coeffs: list, var: np.ndarray) -> sp.Expr:
225246 :arg var: array of sympy variables :math:`[x_0, x_1, \dots]`
226247 """
227248 final_recurrence = 0
228- #Outer loop is derivative direction
229- #Inner is polynomial order of x_0
249+ # Outer loop is derivative direction
250+ # Inner is polynomial order of x_0
230251 for m , _ in enumerate (coeffs ):
231252 for p , _ in enumerate (coeffs [m ]):
232253 final_recurrence += coeffs [m ][p ] * _auto_product_rule_single_term (p ,
@@ -246,7 +267,7 @@ def recurrence_from_pde(pde: LinearPDESystemOperator) -> sp.Expr:
246267 ode_in_r , var , ode_order = pde_to_ode_in_r (pde )
247268 ode_in_x = ode_in_r_to_x (ode_in_r , var , ode_order ).simplify ()
248269 ode_in_x_cleared = (ode_in_x * var [0 ]** (ode_order + 1 )).simplify ()
249- #ode_in_x_cleared shouldn't have rational function coefficients in the coord.
270+ # ode_in_x_cleared shouldn't have rational function coefficients in the coord.
250271 assert sp .together (ode_in_x_cleared ) == ode_in_x_cleared
251272 f_x_derivs = _make_sympy_vec ("f_x" , ode_order + 1 )
252273 poly = sp .Poly (ode_in_x_cleared , * f_x_derivs )
@@ -311,8 +332,8 @@ def test_recurrence_finder_helmholtz_three_d():
311332 """
312333 Tests our recurrence relation generator for Helmhotlz 3D.
313334 """
314- #We are creating the recurrence relation for helmholtz3d which
315- #seems to be an order 5 recurrence relation
335+ # We are creating the recurrence relation for helmholtz3d which
336+ # seems to be an order 5 recurrence relation
316337 w = make_identity_diff_op (3 )
317338 helmholtz3d = laplacian (w ) + w
318339 r = recurrence_from_pde (helmholtz3d )
@@ -326,24 +347,27 @@ def deriv_helmholtz_three_d(i, s_loc):
326347 ) / (sp .sqrt (x ** 2 + y ** 2 + z ** 2 ))
327348 return sp .diff (true_f , x , i ).subs (x , s_x ).subs (
328349 y , s_y ).subs (z , s_z )
329- #Create relevant symbols
350+ # Create relevant symbols
330351 var = _make_sympy_vec ("x" , 3 )
331352 n = sp .symbols ("n" )
332353 s = sp .Function ("s" )
333354
334- #Create random source location
335- s_loc = np .random .rand (3 )
355+ rng = np .random .default_rng ()
356+
357+ # Create random source location
358+ s_loc = rng .uniform (size = 3 )
336359
337- #Create random order to check
338- d = np .random .randint (0 , 5 )
360+ # Create random order to check
361+ from random import randrange
362+ d = randrange (0 , 5 )
339363
340- #Substitute random location into recurrence relation and value of n = d
364+ # Substitute random location into recurrence relation and value of n = d
341365 r_loc = r .subs (var [0 ], s_loc [0 ])
342366 r_loc = r_loc .subs (var [1 ], s_loc [1 ])
343367 r_loc = r_loc .subs (var [2 ], s_loc [2 ])
344368 r_sub = r_loc .subs (n , d )
345369
346- #Checking that the recurrence holds to some machine epsilon
370+ # Checking that the recurrence holds to some machine epsilon
347371 for i in range (max (d - 3 , 0 ), d + 3 ):
348372 # pylint: disable=not-callable
349373 r_sub = r_sub .subs (s (i ), deriv_helmholtz_three_d (i , s_loc ))
0 commit comments