@@ -29,6 +29,13 @@ Some notes:
2929#include " utils/debugTools.h"
3030#endif
3131namespace lanczos {
32+
33+ struct MatrixDot {
34+
35+ virtual void operator ()(real* Mv, real*v) = 0;
36+
37+ };
38+
3239 struct Solver {
3340 Solver (real tolerance = 1e-3 );
3441
@@ -40,8 +47,20 @@ namespace lanczos{
4047
4148 // Given a MatrixDot that computes a product M·v (where M is handled by the MatrixDot), computes Bv = sqrt(M)·v.
4249 // Returns the number of iterations performed.
43- template <class Dotctor > // B = sqrt(M)
44- int solve (Dotctor &dot, real *Bv, real* v, int N);
50+ // B = sqrt(M)
51+ int solve (MatrixDot *dot, real *Bv, real* v, int N);
52+
53+ // Overload for a shared_ptr
54+ int solve (std::shared_ptr<MatrixDot> dot, real *Bv, real* v, int N){
55+ return this ->solve (dot.get (), Bv, v, N);
56+ }
57+
58+ // Overload for an instance
59+ template <class SomeDot >
60+ int solve (SomeDot &dot, real *Bv, real* v, int N){
61+ MatrixDot* ptr = static_cast <MatrixDot*>(&dot);
62+ return this ->solve (ptr, Bv, v, N);
63+ }
4564
4665 // You can use this array as input to the solve operation, which will save some memory
4766 real * getV (int N){
@@ -70,8 +89,7 @@ namespace lanczos{
7089 bool checkConvergence (int current_iterations, real *Bz, real normNoise_prev);
7190 real computeError (real* Bz, real normNoise_prev);
7291 real computeNorm (real *v, int numberElements);
73- template <class Dotctor >
74- void computeIteration (Dotctor &dot, int i, real invz2);
92+ void computeIteration (MatrixDot *dot, int i, real invz2);
7593 void registerRequiredStepsForConverge (int steps_needed);
7694 void resizeIfNeeded (real*z, int N);
7795 int N;
@@ -94,105 +112,6 @@ namespace lanczos{
94112 int check_convergence_steps;
95113 real tolerance;
96114 };
97-
98- template <class Dotctor >
99- int Solver::solve (Dotctor &dot, real *Bz, real*z, int N){
100- // Handles the case of the number of elements changing since last call
101- if (N != this ->N ){
102- real * d_V = detail::getRawPointer (V);
103- if (z == d_V){
104- throw std::runtime_error (" [Lanczos] Size mismatch in input" );
105- }
106- numElementsChanged (N);
107- }
108- /* See algorithm I in [1]*/
109- /* ***********v[0] = z/||z||_2*****/
110- /* If z is not the array provided by getV*/
111- real* d_V = detail::getRawPointer (V);
112- if (z != d_V){
113- detail::device_copy (z, z+N, V.begin ());
114- }
115- /* 1/norm(z)*/
116- real invz2 = 1.0 /computeNorm (d_V, N);
117- /* v[0] = v[0]*1/norm(z)*/
118- device_scal (N, &invz2, d_V, 1 );
119- /* Lanczos iterations for Krylov decomposition*/
120- /* Will perform iterations until Error<=tolerance*/
121- int i = -1 ;
122- real normResult_prev = 1.0 ; // For error estimation, see eq 27 in [1]
123- while (true ){
124- i++;
125- /* Allocate more space if needed*/
126- if (i == max_iter-1 ){
127- #ifdef CUDA_ENABLED
128- CudaSafeCall (cudaDeviceSynchronize ());
129- #endif
130- this ->incrementMaxIterations (2 );
131- }
132- computeIteration (dot, i, invz2);
133- /* Check convergence if needed*/
134- if (i >= check_convergence_steps){ // Miminum of 3 iterations, will be auto tuned
135- /* Compute Bz using h and z*/
136- /* *** y = ||z||_2 * Vm · H^1/2 · e_1 *****/
137- this ->computeCurrentResultEstimation (i, Bz, 1.0 /invz2);
138- /* The first time the result is computed it is only stored as oldBz*/
139- if (i>check_convergence_steps){
140- if (checkConvergence (i, Bz, normResult_prev)){
141- return i;
142- }
143- }
144- /* Always save the current result as oldBz*/
145- detail::device_copy (Bz, Bz+N, oldBz.begin ());
146- /* Store the norm of the result*/
147- real * d_oldBz = detail::getRawPointer (oldBz);
148- device_nrm2 (N, d_oldBz, 1 , &normResult_prev);
149- }
150- }
151- }
152-
153-
154- template <class Dotctor >
155- void Solver::computeIteration (Dotctor &dot, int i, real invz2){
156- real* d_V = detail::getRawPointer (V);
157- real * d_w = detail::getRawPointer (w);
158- /* w = D·vi*/
159- dot (d_V+N*i, d_w);
160- if (i>0 ){
161- /* w = w-h[i-1][i]·vi*/
162- real alpha = -hsup[i-1 ];
163- device_axpy (N,
164- &alpha,
165- d_V+N*(i-1 ), 1 ,
166- d_w, 1 );
167- }
168- /* h[i][i] = dot(w, vi)*/
169- device_dot (N,
170- d_w, 1 ,
171- d_V+N*i, 1 ,
172- &(hdiag[i]));
173- if (i<max_iter-1 ){
174- /* w = w-h[i][i]·vi*/
175- real alpha = -hdiag[i];
176- device_axpy (N,
177- &alpha,
178- d_V+N*i, 1 ,
179- d_w, 1 );
180- /* h[i+1][i] = h[i][i+1] = norm(w)*/
181- device_nrm2 (N, (real*)d_w, 1 , &(hsup[i]));
182- /* v_(i+1) = w·1/ norm(w)*/
183- real tol = 1e-3 *hdiag[i]*invz2;
184- if (hsup[i]<tol) hsup[i] = real (0.0 );
185- if (hsup[i]>real (0.0 )){
186- real invw2 = 1.0 /hsup[i];
187- device_scal (N, &invw2, d_w, 1 );
188- }
189- else {/* If norm(w) = 0 that means all elements of w are zero, so set w = e1*/
190- detail::device_fill (w.begin (), w.end (), real ());
191- w[0 ] = 1 ;
192- }
193- detail::device_copy (w.begin (), w.begin ()+N, V.begin () + N*(i+1 ));
194- }
195- }
196115}
197116
198117#ifndef SHARED_LIBRARY_COMPILATION
0 commit comments