Grid 0.7.0
AdefMrhs.h
Go to the documentation of this file.
1 /*************************************************************************************
2
3 Grid physics library, www.github.com/paboyle/Grid
4
5 Source file: ./lib/algorithms/iterative/AdefMrhs.h
6
7 Copyright (C) 2015
8
9Author: Peter Boyle <paboyle@ph.ed.ac.uk>
10
11 This program is free software; you can redistribute it and/or modify
12 it under the terms of the GNU General Public License as published by
13 the Free Software Foundation; either version 2 of the License, or
14 (at your option) any later version.
15
16 This program is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU General Public License for more details.
20
21 You should have received a copy of the GNU General Public License along
22 with this program; if not, write to the Free Software Foundation, Inc.,
23 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25 See the full license in the file "LICENSE" in the top level distribution directory
26 *************************************************************************************/
27 /* END LEGAL */
28#pragma once
29
30
31 /*
32 * Compared to Tang-2009: P=Pleft. P^T = PRight Q=MssInv.
33 * Script A = SolverMatrix
34 * Script P = Preconditioner
35 *
36 * Implement ADEF-2
37 *
38 * Vstart = P^Tx + Qb
39 * M1 = P^TM + Q
40 * M2=M3=1
41 */
43
44
45template<class Field>
47{
48 public:
52
53 // Fine operator, Smoother, CoarseSolver
57
65
66 /*
67 Field rrr;
68 Field sss;
69 Field qqq;
70 Field zzz;
71 */
72 // more most operator functions
74 Integer maxit,
76 LinearFunction<Field> &Smoother,
77 GridBase *fine) :
78 Tolerance(tol),
79 MaxIterations(maxit),
80 _FineLinop(FineLinop),
81 _Smoother(Smoother)
82 /*
83 rrr(fine),
84 sss(fine),
85 qqq(fine),
86 zzz(fine)
87*/
88 {
89 grid = fine;
90 };
91
92 // Vector case
93 virtual void operator() (std::vector<Field> &src, std::vector<Field> &x)
94 {
95 // SolveSingleSystem(src,x);
96 SolvePrecBlockCG(src,x);
97 }
98
100// Thin QR factorisation (google it)
103 //Dimensions
104 // R_{ferm x Nblock} = Q_{ferm x Nblock} x C_{Nblock x Nblock} -> ferm x Nblock
105 //
106 // Rdag R = m_rr = Herm = L L^dag <-- Cholesky decomposition (LLT routine in Eigen)
107 //
108 // Q C = R => Q = R C^{-1}
109 //
110 // Want Ident = Q^dag Q = C^{-dag} R^dag R C^{-1} = C^{-dag} L L^dag C^{-1} = 1_{Nblock x Nblock}
111 //
112 // Set C = L^{dag}, and then Q^dag Q = ident
113 //
114 // Checks:
115 // Cdag C = Rdag R ; passes.
116 // QdagQ = 1 ; passes
118 void ThinQRfact (Eigen::MatrixXcd &m_zz,
119 Eigen::MatrixXcd &C,
120 Eigen::MatrixXcd &Cinv,
121 std::vector<Field> & Q,
122 std::vector<Field> & MQ,
123 const std::vector<Field> & Z,
124 const std::vector<Field> & MZ)
125 {
126 RealD t0=usecond();
127 _BlockCGLinalg.InnerProductMatrix(m_zz,MZ,Z);
128 RealD t1=usecond();
129
130 m_zz = 0.5*(m_zz+m_zz.adjoint());
131
132 Eigen::MatrixXcd L = m_zz.llt().matrixL();
133
134 C = L.adjoint();
135 Cinv = C.inverse();
136
137 RealD t3=usecond();
138 _BlockCGLinalg.MulMatrix( Q,Cinv,Z);
139 _BlockCGLinalg.MulMatrix(MQ,Cinv,MZ);
140 RealD t4=usecond();
141 std::cout << " ThinQRfact IP :"<< t1-t0<<" us"<<std::endl;
142 std::cout << " ThinQRfact Eigen :"<< t3-t1<<" us"<<std::endl;
143 std::cout << " ThinQRfact MulMat:"<< t4-t3<<" us"<<std::endl;
144 }
145
146 virtual void SolvePrecBlockCG (std::vector<Field> &src, std::vector<Field> &X)
147 {
148 std::cout << GridLogMessage<<"HDCG: mrhs fPrecBlockcg starting"<<std::endl;
149 src[0].Grid()->Barrier();
150 int nrhs = src.size();
151 // std::vector<RealD> f(nrhs);
152 // std::vector<RealD> rtzp(nrhs);
153 // std::vector<RealD> rtz(nrhs);
154 // std::vector<RealD> a(nrhs);
155 // std::vector<RealD> d(nrhs);
156 // std::vector<RealD> b(nrhs);
157 // std::vector<RealD> rptzp(nrhs);
158
160 //Initial residual computation & set up
162 std::vector<RealD> ssq(nrhs);
163 for(int rhs=0;rhs<nrhs;rhs++){
164 ssq[rhs]=norm2(src[rhs]); assert(ssq[rhs]!=0.0);
165 }
166
168 // Fields -- eliminate duplicates between fPcg and block cg
170 std::vector<Field> Mtmp(nrhs,grid);
171 std::vector<Field> tmp(nrhs,grid);
172 std::vector<Field> Z(nrhs,grid); // Rename Z to R
173 std::vector<Field> MZ(nrhs,grid); // Rename MZ to Z
174 std::vector<Field> Q(nrhs,grid); //
175 std::vector<Field> MQ(nrhs,grid); // Rename to P
176 std::vector<Field> D(nrhs,grid);
177 std::vector<Field> AD(nrhs,grid);
178
179 /************************************************************************
180 * Preconditioned Block conjugate gradient rQ
181 * Generalise Sebastien Birk Thesis, after Dubrulle 2001.
182 * Introduce preconditioning following Saad Ch9
183 ************************************************************************
184 * Dimensions:
185 *
186 * X,B etc... ==(Nferm x nrhs)
187 * Matrix A==(Nferm x Nferm)
188 *
189 * Nferm = Nspin x Ncolour x Ncomplex x Nlattice_site
190 * QC => Thin QR factorisation (google it)
191 *
192 * R = B-AX
193 * Z = Mi R
194 * QC = Z
195 * D = Q
196 * for k:
197 * R = AD
198 * Z = Mi R
199 * M = [D^dag R]^{-1}
200 * X = X + D M C
201 * QS = Q - Z.M
202 * D = Q + D S^dag
203 * C = S C
204 */
205 Eigen::MatrixXcd m_DZ = Eigen::MatrixXcd::Identity(nrhs,nrhs);
206 Eigen::MatrixXcd m_M = Eigen::MatrixXcd::Identity(nrhs,nrhs);
207 Eigen::MatrixXcd m_zz = Eigen::MatrixXcd::Zero(nrhs,nrhs);
208 Eigen::MatrixXcd m_rr = Eigen::MatrixXcd::Zero(nrhs,nrhs);
209
210 Eigen::MatrixXcd m_C = Eigen::MatrixXcd::Zero(nrhs,nrhs);
211 Eigen::MatrixXcd m_Cinv = Eigen::MatrixXcd::Zero(nrhs,nrhs);
212 Eigen::MatrixXcd m_S = Eigen::MatrixXcd::Zero(nrhs,nrhs);
213 Eigen::MatrixXcd m_Sinv = Eigen::MatrixXcd::Zero(nrhs,nrhs);
214
215 Eigen::MatrixXcd m_tmp = Eigen::MatrixXcd::Identity(nrhs,nrhs);
216 Eigen::MatrixXcd m_tmp1 = Eigen::MatrixXcd::Identity(nrhs,nrhs);
217
218 GridStopWatch HDCGTimer;
219
221 // x0 = Vstart -- possibly modify guess
223 Vstart(X,src);
224
226 // R = B-AX
228 for(int rhs=0;rhs<nrhs;rhs++){
229 // r0 = b -A x0
230 _FineLinop.HermOp(X[rhs],tmp[rhs]);
231 axpy (Z[rhs], -1.0,tmp[rhs], src[rhs]); // Computes R=Z=src - A X0
232 }
233
235 // Compute MZ = M1 Z = M1 B - M1 A x0
237 PcgM1(Z,MZ);
238
240 // QC = Z
242 ThinQRfact (m_zz, m_C, m_Cinv, Q, MQ, Z, MZ);
243
245 // D=MQ
247 for(int b=0;b<nrhs;b++) D[b]=MQ[b]; // LLT rotation of the MZ basis of search dirs
248
249 std::cout << GridLogMessage<<"PrecBlockCGrQ vec computed initial residual and QR fact " <<std::endl;
250
251 ProjectTimer.Reset();
252 PromoteTimer.Reset();
253 DeflateTimer.Reset();
254 CoarseTimer.Reset();
255 SmoothTimer.Reset();
256 FineTimer.Reset();
257 InsertTimer.Reset();
258
259 GridStopWatch M1Timer;
260 GridStopWatch M2Timer;
261 GridStopWatch M3Timer;
262 GridStopWatch LinalgTimer;
263 GridStopWatch InnerProdTimer;
264
265 HDCGTimer.Start();
266
267 std::vector<RealD> rn(nrhs);
268 for (int k=0;k<=MaxIterations;k++){
269
271 // Z = AD
273 M3Timer.Start();
274 for(int b=0;b<nrhs;b++) _FineLinop.HermOp(D[b], Z[b]);
275 M3Timer.Stop();
276
278 // MZ = M1 Z <==== the Multigrid preconditioner
280 M1Timer.Start();
281 PcgM1(Z,MZ);
282 M1Timer.Stop();
283
284 FineTimer.Start();
286 // M = [D^dag Z]^{-1} = (<Ddag MZ>_M)^{-1} inner prod, generalising Saad derivation of Precon CG
288 InnerProdTimer.Start();
289 _BlockCGLinalg.InnerProductMatrix(m_DZ,D,Z);
290 InnerProdTimer.Stop();
291 m_M = m_DZ.inverse();
292
294 // X = X + D MC
296 m_tmp = m_M * m_C;
297 LinalgTimer.Start();
298 _BlockCGLinalg.MaddMatrix(X,m_tmp, D,X); // D are the search directions and X takes the updates
299 LinalgTimer.Stop();
300
302 // QS = Q - M Z
303 // (MQ) S = MQ - M (M1Z)
305 LinalgTimer.Start();
306 _BlockCGLinalg.MaddMatrix(tmp ,m_M, Z, Q,-1.0);
307 _BlockCGLinalg.MaddMatrix(Mtmp,m_M,MZ,MQ,-1.0);
308 ThinQRfact (m_zz, m_S, m_Sinv, Q, MQ, tmp, Mtmp);
309 LinalgTimer.Stop();
310
312 // D = MQ + D S^dag
314 m_tmp = m_S.adjoint();
315 LinalgTimer.Start();
316 _BlockCGLinalg.MaddMatrix(D,m_tmp,D,MQ);
317 LinalgTimer.Stop();
318
320 // C = S C
322 m_C = m_S*m_C;
323
325 // convergence monitor
327 m_rr = m_C.adjoint() * m_C;
328
329 FineTimer.Stop();
330
331 RealD max_resid=0;
332 RealD rrsum=0;
333 RealD sssum=0;
334 RealD rr;
335
336 for(int b=0;b<nrhs;b++) {
337 rrsum+=real(m_rr(b,b));
338 sssum+=ssq[b];
339 rr = real(m_rr(b,b))/ssq[b];
340 if ( rr > max_resid ) max_resid = rr;
341 }
342 std::cout << GridLogMessage <<
343 "\t Prec BlockCGrQ Iteration "<<k<<" ave resid "<< std::sqrt(rrsum/sssum) << " max "<< std::sqrt(max_resid) <<std::endl;
344
345
346 if ( max_resid < Tolerance*Tolerance ) {
347
348 HDCGTimer.Stop();
349 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ converged in "<<k<<" iterations and "<<HDCGTimer.Elapsed()<<std::endl;;
350 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : Linalg "<<LinalgTimer.Elapsed()<<std::endl;;
351 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : fine H "<<M3Timer.Elapsed()<<std::endl;;
352 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : prec M1 "<<M1Timer.Elapsed()<<std::endl;;
353 std::cout<<GridLogMessage<<"**** M1 breakdown:"<<std::endl;
354 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : Project "<<ProjectTimer.Elapsed()<<std::endl;;
355 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : Promote "<<PromoteTimer.Elapsed()<<std::endl;;
356 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : Deflate "<<DeflateTimer.Elapsed()<<std::endl;;
357 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : Coarse "<<CoarseTimer.Elapsed()<<std::endl;;
358 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : Fine "<<FineTimer.Elapsed()<<std::endl;;
359 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : Smooth "<<SmoothTimer.Elapsed()<<std::endl;;
360 std::cout<<GridLogMessage<<"HDCG: mrhs PrecBlockCGrQ : Insert "<<InsertTimer.Elapsed()<<std::endl;;
361
362 for(int rhs=0;rhs<nrhs;rhs++){
363
364 _FineLinop.HermOp(X[rhs],tmp[rhs]);
365
366 Field mytmp(grid);
367 axpy(mytmp,-1.0,src[rhs],tmp[rhs]);
368
369 RealD xnorm = sqrt(norm2(X[rhs]));
370 RealD srcnorm = sqrt(norm2(src[rhs]));
371 RealD tmpnorm = sqrt(norm2(mytmp));
372 RealD true_residual = tmpnorm/srcnorm;
373 std::cout<<GridLogMessage
374 <<"HDCG: true residual ["<<rhs<<"] is "<<true_residual
375 <<" solution "<<xnorm
376 <<" source "<<srcnorm
377 <<std::endl;
378 }
379 return;
380 }
381
382 }
383 HDCGTimer.Stop();
384 std::cout<<GridLogMessage<<"HDCG: PrecBlockCGrQ not converged "<<HDCGTimer.Elapsed()<<std::endl;
385 assert(0);
386 }
387
388 virtual void SolveSingleSystem (std::vector<Field> &src, std::vector<Field> &x)
389 {
390 std::cout << GridLogMessage<<"HDCG: mrhs fPcg starting"<<std::endl;
391 src[0].Grid()->Barrier();
392 int nrhs = src.size();
393 std::vector<RealD> f(nrhs);
394 std::vector<RealD> rtzp(nrhs);
395 std::vector<RealD> rtz(nrhs);
396 std::vector<RealD> a(nrhs);
397 std::vector<RealD> d(nrhs);
398 std::vector<RealD> b(nrhs);
399 std::vector<RealD> rptzp(nrhs);
401 // Set up history vectors
403 int mmax = 3;
404
405 std::vector<std::vector<Field> > p(nrhs); for(int r=0;r<nrhs;r++) p[r].resize(mmax,grid);
406 std::vector<std::vector<Field> > mmp(nrhs); for(int r=0;r<nrhs;r++) mmp[r].resize(mmax,grid);
407 std::vector<std::vector<RealD> > pAp(nrhs); for(int r=0;r<nrhs;r++) pAp[r].resize(mmax);
408
409 std::vector<Field> z(nrhs,grid);
410 std::vector<Field> mp (nrhs,grid);
411 std::vector<Field> r (nrhs,grid);
412 std::vector<Field> mu (nrhs,grid);
413
414 //Initial residual computation & set up
415 std::vector<RealD> src_nrm(nrhs);
416 for(int rhs=0;rhs<nrhs;rhs++) {
417 src_nrm[rhs]=norm2(src[rhs]);
418 assert(src_nrm[rhs]!=0.0);
419 }
420 std::vector<RealD> tn(nrhs);
421
422 GridStopWatch HDCGTimer;
424 // x0 = Vstart -- possibly modify guess
426 Vstart(x,src);
427
428 for(int rhs=0;rhs<nrhs;rhs++){
429 // r0 = b -A x0
430 _FineLinop.HermOp(x[rhs],mmp[rhs][0]);
431 axpy (r[rhs], -1.0,mmp[rhs][0], src[rhs]); // Recomputes r=src-Ax0
432 }
433
435 // Compute z = M1 x
437 // This needs a multiRHS version for acceleration
438 PcgM1(r,z);
439
440 std::vector<RealD> ssq(nrhs);
441 std::vector<RealD> rsq(nrhs);
442 std::vector<Field> pp(nrhs,grid);
443
444 for(int rhs=0;rhs<nrhs;rhs++){
445 rtzp[rhs] =real(innerProduct(r[rhs],z[rhs]));
446 p[rhs][0]=z[rhs];
447 ssq[rhs]=norm2(src[rhs]);
448 rsq[rhs]= ssq[rhs]*Tolerance*Tolerance;
449 // std::cout << GridLogMessage<<"mrhs HDCG: "<<rhs<<" k=0 residual "<<rtzp[rhs]<<" rsq "<<rsq[rhs]<<"\n";
450 }
451
452 ProjectTimer.Reset();
453 PromoteTimer.Reset();
454 DeflateTimer.Reset();
455 CoarseTimer.Reset();
456 SmoothTimer.Reset();
457 FineTimer.Reset();
458 InsertTimer.Reset();
459
460 GridStopWatch M1Timer;
461 GridStopWatch M2Timer;
462 GridStopWatch M3Timer;
463 GridStopWatch LinalgTimer;
464
465 HDCGTimer.Start();
466
467 std::vector<RealD> rn(nrhs);
468 for (int k=0;k<=MaxIterations;k++){
469
470 int peri_k = k % mmax;
471 int peri_kp = (k+1) % mmax;
472
473 for(int rhs=0;rhs<nrhs;rhs++){
474 rtz[rhs]=rtzp[rhs];
475 M3Timer.Start();
476 d[rhs]= PcgM3(p[rhs][peri_k],mmp[rhs][peri_k]);
477 M3Timer.Stop();
478 a[rhs] = rtz[rhs]/d[rhs];
479
480 LinalgTimer.Start();
481 // Memorise this
482 pAp[rhs][peri_k] = d[rhs];
483
484 axpy(x[rhs],a[rhs],p[rhs][peri_k],x[rhs]);
485 rn[rhs] = axpy_norm(r[rhs],-a[rhs],mmp[rhs][peri_k],r[rhs]);
486 LinalgTimer.Stop();
487 }
488
489 // Compute z = M x (for *all* RHS)
490 M1Timer.Start();
491 PcgM1(r,z);
492 M1Timer.Stop();
493
494 RealD max_rn=0.0;
495 LinalgTimer.Start();
496 for(int rhs=0;rhs<nrhs;rhs++){
497
498 rtzp[rhs] =real(innerProduct(r[rhs],z[rhs]));
499
500 // std::cout << GridLogMessage<<"HDCG::fPcg rhs"<<rhs<<" iteration "<<k<<" : inner rtzp "<<rtzp[rhs]<<"\n";
501 mu[rhs]=z[rhs];
502
503 p[rhs][peri_kp]=mu[rhs];
504
505 // Standard search direction p == z + b p
506 b[rhs] = (rtzp[rhs])/rtz[rhs];
507
508 int northog = (k>mmax-1)?(mmax-1):k; // This is the fCG-Tr(mmax-1) algorithm
509 for(int back=0; back < northog; back++){
510 int peri_back = (k-back)%mmax;
511 RealD pbApk= real(innerProduct(mmp[rhs][peri_back],p[rhs][peri_kp]));
512 RealD beta = -pbApk/pAp[rhs][peri_back];
513 axpy(p[rhs][peri_kp],beta,p[rhs][peri_back],p[rhs][peri_kp]);
514 }
515
516 RealD rrn=sqrt(rn[rhs]/ssq[rhs]);
517 RealD rtn=sqrt(rtz[rhs]/ssq[rhs]);
518 RealD rtnp=sqrt(rtzp[rhs]/ssq[rhs]);
519
520 std::cout<<GridLogMessage<<"HDCG:fPcg rhs "<<rhs<<" k= "<<k<<" residual = "<<rrn<<"\n";
521 if ( rrn > max_rn ) max_rn = rrn;
522 }
523 LinalgTimer.Stop();
524
525 // Stopping condition based on worst case
526 if ( max_rn <= Tolerance ) {
527
528 HDCGTimer.Stop();
529 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg converged in "<<k<<" iterations and "<<HDCGTimer.Elapsed()<<std::endl;;
530 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : Linalg "<<LinalgTimer.Elapsed()<<std::endl;;
531 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : fine M3 "<<M3Timer.Elapsed()<<std::endl;;
532 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : prec M1 "<<M1Timer.Elapsed()<<std::endl;;
533 std::cout<<GridLogMessage<<"**** M1 breakdown:"<<std::endl;
534 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : Project "<<ProjectTimer.Elapsed()<<std::endl;;
535 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : Promote "<<PromoteTimer.Elapsed()<<std::endl;;
536 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : Deflate "<<DeflateTimer.Elapsed()<<std::endl;;
537 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : Coarse "<<CoarseTimer.Elapsed()<<std::endl;;
538 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : Fine "<<FineTimer.Elapsed()<<std::endl;;
539 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : Smooth "<<SmoothTimer.Elapsed()<<std::endl;;
540 std::cout<<GridLogMessage<<"HDCG: mrhs fPcg : Insert "<<InsertTimer.Elapsed()<<std::endl;;
541
542 for(int rhs=0;rhs<nrhs;rhs++){
543 _FineLinop.HermOp(x[rhs],mmp[rhs][0]);
544 Field tmp(grid);
545 axpy(tmp,-1.0,src[rhs],mmp[rhs][0]);
546
547 RealD mmpnorm = sqrt(norm2(mmp[rhs][0]));
548 RealD xnorm = sqrt(norm2(x[rhs]));
549 RealD srcnorm = sqrt(norm2(src[rhs]));
550 RealD tmpnorm = sqrt(norm2(tmp));
551 RealD true_residual = tmpnorm/srcnorm;
552 std::cout<<GridLogMessage
553 <<"HDCG: true residual ["<<rhs<<"] is "<<true_residual
554 <<" solution "<<xnorm
555 <<" source "<<srcnorm
556 <<" mmp "<<mmpnorm
557 <<std::endl;
558 }
559 return;
560 }
561
562 }
563 HDCGTimer.Stop();
564 std::cout<<GridLogMessage<<"HDCG: not converged "<<HDCGTimer.Elapsed()<<std::endl;
565 for(int rhs=0;rhs<nrhs;rhs++){
566 RealD xnorm = sqrt(norm2(x[rhs]));
567 RealD srcnorm = sqrt(norm2(src[rhs]));
568 std::cout<<GridLogMessage<<"HDCG: non-converged solution "<<xnorm<<" source "<<srcnorm<<std::endl;
569 }
570 }
571
572
573 public:
574
575 virtual void PcgM1(std::vector<Field> & in,std::vector<Field> & out) = 0;
576 virtual void Vstart(std::vector<Field> & x,std::vector<Field> & src) = 0;
577 virtual void PcgM2(const Field & in, Field & out) {
578 out=in;
579 }
580
581 virtual RealD PcgM3(const Field & p, Field & mmp){
582 RealD dd;
583 _FineLinop.HermOp(p,mmp);
584 ComplexD dot = innerProduct(p,mmp);
585 dd=real(dot);
586 return dd;
587 }
588
589};
590
591template<class Field, class CoarseField>
593{
594public:
601
602
604 Integer maxit,
605 LinearOperatorBase<Field> &FineLinop,
606 LinearFunction<Field> &Smoother,
607 LinearFunction<CoarseField> &CoarseSolverMrhs,
608 LinearFunction<CoarseField> &CoarseSolverPreciseMrhs,
611 GridBase *_coarsemrhsgrid) :
612 TwoLevelCGmrhs<Field>(tol, maxit,FineLinop,Smoother,Projector.fine_grid),
613 _CoarseSolverMrhs(CoarseSolverMrhs),
614 _CoarseSolverPreciseMrhs(CoarseSolverPreciseMrhs),
615 _Projector(Projector),
616 _Deflator(Deflator)
617 {
618 coarsegrid = Projector.coarse_grid;
619 coarsegridmrhs = _coarsemrhsgrid;// Thi could be in projector
620 };
621
622 // Override Vstart
623 virtual void Vstart(std::vector<Field> & x,std::vector<Field> & src)
624 {
625 int nrhs=x.size();
627 // Choose x_0 such that
628 // x_0 = guess + (A_ss^inv) r_s = guess + Ass_inv [src -Aguess]
629 // = [1 - Ass_inv A] Guess + Assinv src
630 // = P^T guess + Assinv src
631 // = Vstart [Tang notation]
632 // This gives:
633 // W^T (src - A x_0) = src_s - A guess_s - r_s
634 // = src_s - (A guess)_s - src_s + (A guess)_s
635 // = 0
637 std::vector<CoarseField> PleftProj(nrhs,this->coarsegrid);
638 std::vector<CoarseField> PleftMss_proj(nrhs,this->coarsegrid);
639 CoarseField PleftProjMrhs(this->coarsegridmrhs);
640 CoarseField PleftMss_projMrhs(this->coarsegridmrhs);
641
642 this->_Projector.blockProject(src,PleftProj);
643 this->_Deflator.DeflateSources(PleftProj,PleftMss_proj);
644 for(int rhs=0;rhs<nrhs;rhs++) {
645 InsertSliceFast(PleftProj[rhs],PleftProjMrhs,rhs,0);
646 InsertSliceFast(PleftMss_proj[rhs],PleftMss_projMrhs,rhs,0); // the guess
647 }
648
649 this->_CoarseSolverPreciseMrhs(PleftProjMrhs,PleftMss_projMrhs); // Ass^{-1} r_s
650
651 for(int rhs=0;rhs<nrhs;rhs++) {
652 ExtractSliceFast(PleftMss_proj[rhs],PleftMss_projMrhs,rhs,0);
653 }
654 this->_Projector.blockPromote(x,PleftMss_proj);
655 }
656
657 virtual void PcgM1(std::vector<Field> & in,std::vector<Field> & out){
658
659 int nrhs=in.size();
660
661 // [PTM+Q] in = [1 - Q A] M in + Q in = Min + Q [ in -A Min]
662 std::vector<Field> tmp(nrhs,this->grid);
663 std::vector<Field> Min(nrhs,this->grid);
664
665 std::vector<CoarseField> PleftProj(nrhs,this->coarsegrid);
666 std::vector<CoarseField> PleftMss_proj(nrhs,this->coarsegrid);
667
668 CoarseField PleftProjMrhs(this->coarsegridmrhs);
669 CoarseField PleftMss_projMrhs(this->coarsegridmrhs);
670
671 // this->rrr=in[0];
672
673#undef SMOOTHER_BLOCK_SOLVE
674#if SMOOTHER_BLOCK_SOLVE
675 this->SmoothTimer.Start();
676 this->_Smoother(in,Min);
677 this->SmoothTimer.Stop();
678#else
679 for(int rhs=0;rhs<nrhs;rhs++) {
680 this->SmoothTimer.Start();
681 this->_Smoother(in[rhs],Min[rhs]);
682 this->SmoothTimer.Stop();
683 }
684#endif
685 // this->sss=Min[0];
686
687 for(int rhs=0;rhs<nrhs;rhs++) {
688
689 this->FineTimer.Start();
690 this->_FineLinop.HermOp(Min[rhs],out[rhs]);
691 axpy(tmp[rhs],-1.0,out[rhs],in[rhs]); // resid = in - A Min
692 this->FineTimer.Stop();
693
694 }
695
696 this->ProjectTimer.Start();
697 this->_Projector.blockProject(tmp,PleftProj);
698 this->ProjectTimer.Stop();
699 this->DeflateTimer.Start();
700 this->_Deflator.DeflateSources(PleftProj,PleftMss_proj);
701 this->DeflateTimer.Stop();
702 this->InsertTimer.Start();
703 for(int rhs=0;rhs<nrhs;rhs++) {
704 InsertSliceFast(PleftProj[rhs],PleftProjMrhs,rhs,0);
705 InsertSliceFast(PleftMss_proj[rhs],PleftMss_projMrhs,rhs,0); // the guess
706 }
707 this->InsertTimer.Stop();
708
709 this->CoarseTimer.Start();
710 this->_CoarseSolverMrhs(PleftProjMrhs,PleftMss_projMrhs); // Ass^{-1} [in - A Min]_s
711 this->CoarseTimer.Stop();
712
713 this->InsertTimer.Start();
714 for(int rhs=0;rhs<nrhs;rhs++) {
715 ExtractSliceFast(PleftMss_proj[rhs],PleftMss_projMrhs,rhs,0);
716 }
717 this->InsertTimer.Stop();
718 this->PromoteTimer.Start();
719 this->_Projector.blockPromote(tmp,PleftMss_proj);// tmp= Q[in - A Min]
720 this->PromoteTimer.Stop();
721 this->FineTimer.Start();
722 // this->qqq=tmp[0];
723 for(int rhs=0;rhs<nrhs;rhs++) {
724 axpy(out[rhs],1.0,Min[rhs],tmp[rhs]); // Min+tmp
725 }
726 // this->zzz=out[0];
727 this->FineTimer.Stop();
728 }
729};
730
731
733
734
accelerator_inline Grid_simd< S, V > sqrt(const Grid_simd< S, V > &r)
RealD axpy_norm(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
Lattice< vobj > real(const Lattice< vobj > &lhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
void InsertSliceFast(const Lattice< vobj > &From, Lattice< vobj > &To, int slice, int orthog)
void ExtractSliceFast(Lattice< vobj > &To, const Lattice< vobj > &From, int slice, int orthog)
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
double usecond(void)
Definition Timer.h:50
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
void blockProject(std::vector< Field > &fine, std::vector< Lattice< cobj > > &coarse)
void blockPromote(std::vector< Field > &fine, std::vector< Lattice< cobj > > &coarse)
void DeflateSources(std::vector< Field > &source, std::vector< Field > &guess)
MultiRHSBlockProject< Field > & _Projector
Definition AdefMrhs.h:599
GridBase * coarsegridmrhs
Definition AdefMrhs.h:596
GridBase * coarsegrid
Definition AdefMrhs.h:595
LinearFunction< CoarseField > & _CoarseSolverMrhs
Definition AdefMrhs.h:597
MultiRHSDeflation< CoarseField > & _Deflator
Definition AdefMrhs.h:600
virtual void Vstart(std::vector< Field > &x, std::vector< Field > &src)
Definition AdefMrhs.h:623
virtual void PcgM1(std::vector< Field > &in, std::vector< Field > &out)
Definition AdefMrhs.h:657
LinearFunction< CoarseField > & _CoarseSolverPreciseMrhs
Definition AdefMrhs.h:598
TwoLevelADEF2mrhs(RealD tol, Integer maxit, LinearOperatorBase< Field > &FineLinop, LinearFunction< Field > &Smoother, LinearFunction< CoarseField > &CoarseSolverMrhs, LinearFunction< CoarseField > &CoarseSolverPreciseMrhs, MultiRHSBlockProject< Field > &Projector, MultiRHSDeflation< CoarseField > &Deflator, GridBase *_coarsemrhsgrid)
Definition AdefMrhs.h:603
LinearOperatorBase< Field > & _FineLinop
Definition AdefMrhs.h:54
virtual void PcgM1(std::vector< Field > &in, std::vector< Field > &out)=0
GridStopWatch FineTimer
Definition AdefMrhs.h:62
MultiRHSBlockCGLinalg< Field > _BlockCGLinalg
Definition AdefMrhs.h:56
GridBase * grid
Definition AdefMrhs.h:51
Integer MaxIterations
Definition AdefMrhs.h:50
GridStopWatch PromoteTimer
Definition AdefMrhs.h:59
GridStopWatch DeflateTimer
Definition AdefMrhs.h:60
virtual void operator()(std::vector< Field > &src, std::vector< Field > &x)
Definition AdefMrhs.h:93
virtual RealD PcgM3(const Field &p, Field &mmp)
Definition AdefMrhs.h:581
GridStopWatch SmoothTimer
Definition AdefMrhs.h:63
GridStopWatch CoarseTimer
Definition AdefMrhs.h:61
TwoLevelCGmrhs(RealD tol, Integer maxit, LinearOperatorBase< Field > &FineLinop, LinearFunction< Field > &Smoother, GridBase *fine)
Definition AdefMrhs.h:73
LinearFunction< Field > & _Smoother
Definition AdefMrhs.h:55
GridStopWatch ProjectTimer
Definition AdefMrhs.h:58
virtual void SolveSingleSystem(std::vector< Field > &src, std::vector< Field > &x)
Definition AdefMrhs.h:388
RealD Tolerance
Definition AdefMrhs.h:49
void ThinQRfact(Eigen::MatrixXcd &m_zz, Eigen::MatrixXcd &C, Eigen::MatrixXcd &Cinv, std::vector< Field > &Q, std::vector< Field > &MQ, const std::vector< Field > &Z, const std::vector< Field > &MZ)
Definition AdefMrhs.h:118
virtual void SolvePrecBlockCG(std::vector< Field > &src, std::vector< Field > &X)
Definition AdefMrhs.h:146
virtual void Vstart(std::vector< Field > &x, std::vector< Field > &src)=0
GridStopWatch InsertTimer
Definition AdefMrhs.h:64
virtual void PcgM2(const Field &in, Field &out)
Definition AdefMrhs.h:577