Skip to content

Commit e363248

Browse files
committed
Hacking during meeting
1 parent bfa8372 commit e363248

1 file changed

Lines changed: 53 additions & 29 deletions

File tree

sumpy/recurrence.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
from __future__ import annotations
3030

31+
from typing import TypeVar
32+
3133

3234
__copyright__ = """
3335
Copyright (C) 2024 Hirish Chandrasekaran
@@ -54,11 +56,18 @@
5456
THE SOFTWARE.
5557
"""
5658
import math
59+
5760
import numpy as np
5861
import sympy as sp
62+
5963
from pytools.obj_array import make_obj_array
64+
6065
from sumpy.expansion.diff_op import (
61-
DerivativeIdentifier, make_identity_diff_op, laplacian, LinearPDESystemOperator)
66+
DerivativeIdentifier,
67+
LinearPDESystemOperator,
68+
laplacian,
69+
make_identity_diff_op,
70+
)
6271

6372

6473
# similar to make_sym_vector in sumpy.symbolic, but returns an object array
@@ -166,8 +175,11 @@ def ode_in_r_to_x(ode_in_r: sp.Expr, var: np.ndarray, ode_order: int) -> sp.Expr
166175
return ode_in_x
167176

168177

178+
ODECoefficients = list[list[sp.Expr]]
179+
180+
169181
def ode_in_x_to_coeff_array(poly: sp.Poly, ode_order: int,
170-
var: np.ndarray) -> list:
182+
var: np.ndarray) -> ODECoefficients:
171183
r"""
172184
Organizes the coefficients of an ODE in the :math:`x_0` variable into a 2D array.
173185
@@ -184,11 +196,26 @@ def ode_in_x_to_coeff_array(poly: sp.Poly, ode_order: int,
184196
:math:`x_0`, so that, in terms of the above form, coeffs is
185197
:math:`[[b_{00}, b_{01}, ...], [b_{10}, b_{11}, ...], ...]`
186198
"""
187-
def kronecker(i, n=ode_order+1):
188-
return tuple(1 if i == j else 0 for j in range(n))
199+
return [
200+
# recast ODE coefficient obtained below as polynomial in x0
201+
sp.Poly(
202+
# get coefficient of deriv_ind'th derivative
203+
poly.coeff_monomial(poly.gens[deriv_ind]),
204+
205+
var[0])
206+
# get poly coefficients in /ascending/ order
207+
.all_coeffs()[::-1]
208+
for deriv_ind in range(ode_order+1)]
209+
189210

190-
return [sp.Poly(poly.coeff_monomial(kronecker(deriv_ind)),
191-
var[0]).all_coeffs()[::-1] for deriv_ind in range(ode_order+1)]
211+
NumberT = TypeVar("NumberT", int, float, complex)
212+
213+
214+
def _falling_factorial(arg: NumberT, num_terms: int) -> NumberT:
215+
result = 1
216+
for i in range(num_terms):
217+
result = result * (arg - i)
218+
return result
192219

193220

194221
def _auto_product_rule_single_term(p: int, m: int, var: np.ndarray) -> sp.Expr:
@@ -198,21 +225,15 @@ def _auto_product_rule_single_term(p: int, m: int, var: np.ndarray) -> sp.Expr:
198225
variable.
199226
We let :math:`s(i)` represent the ith order derivative of f when
200227
we output the final result.
201-
:arg p: see description
202-
:arg m: see description
203228
:arg var: array of sympy variables :math:`[x_0, x_1, \dots]`
204229
"""
205230
n = sp.symbols("n")
206231
s = sp.Function("s")
207-
result = 0
208-
for i in range(p+1):
209-
temp = 1
210-
for j in range(i):
211-
temp *= (n - j)
212-
# pylint: disable=not-callable
213-
temp *= math.comb(p, i) * s(n-i+m) * var[0]**(p-i)
214-
result += temp
215-
return result
232+
return sum(
233+
_falling_factorial(n, i)
234+
* math.comb(p, i) * s(n-i+m) * var[0]**(p-i)
235+
for i in range(p+1)
236+
)
216237

217238

218239
def recurrence_from_coeff_array(coeffs: list, var: np.ndarray) -> sp.Expr:
@@ -225,8 +246,8 @@ def recurrence_from_coeff_array(coeffs: list, var: np.ndarray) -> sp.Expr:
225246
:arg var: array of sympy variables :math:`[x_0, x_1, \dots]`
226247
"""
227248
final_recurrence = 0
228-
#Outer loop is derivative direction
229-
#Inner is polynomial order of x_0
249+
# Outer loop is derivative direction
250+
# Inner is polynomial order of x_0
230251
for m, _ in enumerate(coeffs):
231252
for p, _ in enumerate(coeffs[m]):
232253
final_recurrence += coeffs[m][p] * _auto_product_rule_single_term(p,
@@ -246,7 +267,7 @@ def recurrence_from_pde(pde: LinearPDESystemOperator) -> sp.Expr:
246267
ode_in_r, var, ode_order = pde_to_ode_in_r(pde)
247268
ode_in_x = ode_in_r_to_x(ode_in_r, var, ode_order).simplify()
248269
ode_in_x_cleared = (ode_in_x * var[0]**(ode_order+1)).simplify()
249-
#ode_in_x_cleared shouldn't have rational function coefficients in the coord.
270+
# ode_in_x_cleared shouldn't have rational function coefficients in the coord.
250271
assert sp.together(ode_in_x_cleared) == ode_in_x_cleared
251272
f_x_derivs = _make_sympy_vec("f_x", ode_order+1)
252273
poly = sp.Poly(ode_in_x_cleared, *f_x_derivs)
@@ -311,8 +332,8 @@ def test_recurrence_finder_helmholtz_three_d():
311332
"""
312333
Tests our recurrence relation generator for Helmhotlz 3D.
313334
"""
314-
#We are creating the recurrence relation for helmholtz3d which
315-
#seems to be an order 5 recurrence relation
335+
# We are creating the recurrence relation for helmholtz3d which
336+
# seems to be an order 5 recurrence relation
316337
w = make_identity_diff_op(3)
317338
helmholtz3d = laplacian(w) + w
318339
r = recurrence_from_pde(helmholtz3d)
@@ -326,24 +347,27 @@ def deriv_helmholtz_three_d(i, s_loc):
326347
) / (sp.sqrt(x**2 + y**2 + z**2))
327348
return sp.diff(true_f, x, i).subs(x, s_x).subs(
328349
y, s_y).subs(z, s_z)
329-
#Create relevant symbols
350+
# Create relevant symbols
330351
var = _make_sympy_vec("x", 3)
331352
n = sp.symbols("n")
332353
s = sp.Function("s")
333354

334-
#Create random source location
335-
s_loc = np.random.rand(3)
355+
rng = np.random.default_rng()
356+
357+
# Create random source location
358+
s_loc = rng.uniform(size=3)
336359

337-
#Create random order to check
338-
d = np.random.randint(0, 5)
360+
# Create random order to check
361+
from random import randrange
362+
d = randrange(0, 5)
339363

340-
#Substitute random location into recurrence relation and value of n = d
364+
# Substitute random location into recurrence relation and value of n = d
341365
r_loc = r.subs(var[0], s_loc[0])
342366
r_loc = r_loc.subs(var[1], s_loc[1])
343367
r_loc = r_loc.subs(var[2], s_loc[2])
344368
r_sub = r_loc.subs(n, d)
345369

346-
#Checking that the recurrence holds to some machine epsilon
370+
# Checking that the recurrence holds to some machine epsilon
347371
for i in range(max(d-3, 0), d+3):
348372
# pylint: disable=not-callable
349373
r_sub = r_sub.subs(s(i), deriv_helmholtz_three_d(i, s_loc))

0 commit comments

Comments
 (0)