Skip to content

Commit cfabd12

Browse files
author
Abhishek Bagusetty
committed
_iterative: use math.* for host scalar arithmetic; numpy stays for finfo/dtype/lstsq
1 parent f910092 commit cfabd12

1 file changed

Lines changed: 13 additions & 12 deletions

File tree

dpnp/scipy/sparse/linalg/_iterative.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070

7171
from __future__ import annotations
7272

73+
import math
7374
from typing import Callable
7475

7576
import dpctl.utils as dpu
@@ -88,7 +89,7 @@
8889

8990

9091
def _np_dtype(dp_dtype) -> numpy.dtype:
91-
"""Normalise any dtype-like (dpnp type/numpy type/string) to numpy.dtype."""
92+
"""Normalise any dtype-like to numpy.dtype."""
9293
return numpy.dtype(dp_dtype)
9394

9495

@@ -545,7 +546,7 @@ def cg(
545546
if rnorm_host <= atol_eff_host:
546547
info = 0
547548
break
548-
if not numpy.isfinite(rnorm_host):
549+
if not math.isfinite(rnorm_host):
549550
# IEEE-propagated breakdown: pAp or rz collapsed in the
550551
# previous iteration, poisoning r via alpha=inf/NaN. The
551552
# current iterate is the best estimate we have; report
@@ -688,7 +689,7 @@ def gmres(
688689
# the convergence test and the final maxiter check below operate on
689690
# host scalars (one explicit sync per restart, not an implicit one
690691
# per comparison).
691-
r_norm_host = numpy.inf
692+
r_norm_host = math.inf
692693
while True:
693694
mx = psolve(x)
694695
r = b - matvec(mx)
@@ -714,8 +715,8 @@ def gmres(
714715
# by compute_hu; without this reset stale values from the
715716
# previous restart would leak into the system.
716717
H[:] = 0
717-
# RHS for the Hessenberg system is r_norm * e_1; the rest of
718-
# e_host stays zero from the numpy.zeros allocation above.
718+
# RHS for the Hessenberg system is r_norm * e_1; the rest
719+
# of e_host stays zero from the host allocation above.
719720
e_host[0] = r_norm_host
720721
if iters > 0:
721722
# Clear stale tail from previous restart in case maxiter
@@ -874,7 +875,7 @@ def minres(
874875
if beta1 == 0:
875876
return (x, 0)
876877

877-
beta1 = numpy.sqrt(beta1)
878+
beta1 = math.sqrt(beta1)
878879

879880
if check:
880881
# See if A is symmetric. All on device; only the bool syncs.
@@ -937,7 +938,7 @@ def minres(
937938
beta = float(dpnp.inner(r2, y))
938939
if beta < 0:
939940
raise ValueError("non-symmetric matrix")
940-
beta = numpy.sqrt(beta)
941+
beta = math.sqrt(beta)
941942

942943
tnorm2 += alpha**2 + oldb**2 + beta**2
943944

@@ -953,10 +954,10 @@ def minres(
953954
gbar = sn * dbar - cs * alpha
954955
epsln = sn * beta
955956
dbar = -cs * beta
956-
root = numpy.sqrt(gbar**2 + dbar**2)
957+
root = math.hypot(gbar, dbar)
957958

958959
# Compute the next plane rotation Q_k.
959-
gamma = numpy.sqrt(gbar**2 + beta**2)
960+
gamma = math.hypot(gbar, beta)
960961
gamma = max(gamma, eps)
961962
cs = gbar / gamma
962963
sn = beta / gamma
@@ -980,7 +981,7 @@ def minres(
980981
# ----------------------------------------------------------
981982
# Estimate norms and test for convergence.
982983
# ----------------------------------------------------------
983-
Anorm = numpy.sqrt(tnorm2)
984+
Anorm = math.sqrt(tnorm2)
984985
ynorm = float(dpnp.linalg.norm(x)) # host sync #3
985986
epsa = Anorm * eps
986987
epsx = Anorm * ynorm * eps
@@ -992,11 +993,11 @@ def minres(
992993
qrnorm = phibar
993994
rnorm = qrnorm
994995
if ynorm == 0 or Anorm == 0:
995-
test1 = numpy.inf
996+
test1 = math.inf
996997
else:
997998
test1 = rnorm / (Anorm * ynorm) # ||r|| / (||A|| ||x||)
998999
if Anorm == 0:
999-
test2 = numpy.inf
1000+
test2 = math.inf
10001001
else:
10011002
test2 = root / Anorm # ||Ar|| / (||A|| ||r||)
10021003

0 commit comments

Comments
 (0)