7070
7171from __future__ import annotations
7272
73+ import math
7374from typing import Callable
7475
7576import dpctl .utils as dpu
8889
8990
9091def _np_dtype (dp_dtype ) -> numpy .dtype :
91- """Normalise any dtype-like (dpnp type/numpy type/string) to numpy.dtype."""
92+ """Normalise any dtype-like to numpy.dtype."""
9293 return numpy .dtype (dp_dtype )
9394
9495
@@ -545,7 +546,7 @@ def cg(
545546 if rnorm_host <= atol_eff_host :
546547 info = 0
547548 break
548- if not numpy .isfinite (rnorm_host ):
549+ if not math .isfinite (rnorm_host ):
549550 # IEEE-propagated breakdown: pAp or rz collapsed in the
550551 # previous iteration, poisoning r via alpha=inf/NaN. The
551552 # current iterate is the best estimate we have; report
@@ -688,7 +689,7 @@ def gmres(
688689 # the convergence test and the final maxiter check below operate on
689690 # host scalars (one explicit sync per restart, not an implicit one
690691 # per comparison).
691- r_norm_host = numpy .inf
692+ r_norm_host = math .inf
692693 while True :
693694 mx = psolve (x )
694695 r = b - matvec (mx )
@@ -714,8 +715,8 @@ def gmres(
714715 # by compute_hu; without this reset stale values from the
715716 # previous restart would leak into the system.
716717 H [:] = 0
717- # RHS for the Hessenberg system is r_norm * e_1; the rest of
718- # e_host stays zero from the numpy.zeros allocation above.
718+ # RHS for the Hessenberg system is r_norm * e_1; the rest
719+ # of e_host stays zero from the host allocation above.
719720 e_host [0 ] = r_norm_host
720721 if iters > 0 :
721722 # Clear stale tail from previous restart in case maxiter
@@ -874,7 +875,7 @@ def minres(
874875 if beta1 == 0 :
875876 return (x , 0 )
876877
877- beta1 = numpy .sqrt (beta1 )
878+ beta1 = math .sqrt (beta1 )
878879
879880 if check :
880881 # See if A is symmetric. All on device; only the bool syncs.
@@ -937,7 +938,7 @@ def minres(
937938 beta = float (dpnp .inner (r2 , y ))
938939 if beta < 0 :
939940 raise ValueError ("non-symmetric matrix" )
940- beta = numpy .sqrt (beta )
941+ beta = math .sqrt (beta )
941942
942943 tnorm2 += alpha ** 2 + oldb ** 2 + beta ** 2
943944
@@ -953,10 +954,10 @@ def minres(
953954 gbar = sn * dbar - cs * alpha
954955 epsln = sn * beta
955956 dbar = - cs * beta
956- root = numpy . sqrt (gbar ** 2 + dbar ** 2 )
957+ root = math . hypot (gbar , dbar )
957958
958959 # Compute the next plane rotation Q_k.
959- gamma = numpy . sqrt (gbar ** 2 + beta ** 2 )
960+ gamma = math . hypot (gbar , beta )
960961 gamma = max (gamma , eps )
961962 cs = gbar / gamma
962963 sn = beta / gamma
@@ -980,7 +981,7 @@ def minres(
980981 # ----------------------------------------------------------
981982 # Estimate norms and test for convergence.
982983 # ----------------------------------------------------------
983- Anorm = numpy .sqrt (tnorm2 )
984+ Anorm = math .sqrt (tnorm2 )
984985 ynorm = float (dpnp .linalg .norm (x )) # host sync #3
985986 epsa = Anorm * eps
986987 epsx = Anorm * ynorm * eps
@@ -992,11 +993,11 @@ def minres(
992993 qrnorm = phibar
993994 rnorm = qrnorm
994995 if ynorm == 0 or Anorm == 0 :
995- test1 = numpy .inf
996+ test1 = math .inf
996997 else :
997998 test1 = rnorm / (Anorm * ynorm ) # ||r|| / (||A|| ||x||)
998999 if Anorm == 0 :
999- test2 = numpy .inf
1000+ test2 = math .inf
10001001 else :
10011002 test2 = root / Anorm # ||Ar|| / (||A|| ||r||)
10021003
0 commit comments