Grid 0.7.0
ConjugateGradientMultiShift.h
Go to the documentation of this file.
1/*************************************************************************************
2
3 Grid physics library, www.github.com/paboyle/Grid
4
5 Source file: ./lib/algorithms/iterative/ConjugateGradientMultiShift.h
6
7 Copyright (C) 2015
8
9Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
10Author: Peter Boyle <paboyle@ph.ed.ac.uk>
11
12 This program is free software; you can redistribute it and/or modify
13 it under the terms of the GNU General Public License as published by
14 the Free Software Foundation; either version 2 of the License, or
15 (at your option) any later version.
16
17 This program is distributed in the hope that it will be useful,
18 but WITHOUT ANY WARRANTY; without even the implied warranty of
19 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20 GNU General Public License for more details.
21
22 You should have received a copy of the GNU General Public License along
23 with this program; if not, write to the Free Software Foundation, Inc.,
24 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
25
26 See the full license in the file "LICENSE" in the top level distribution directory
27*************************************************************************************/
28/* END LEGAL */
29#ifndef GRID_CONJUGATE_MULTI_SHIFT_GRADIENT_H
30#define GRID_CONJUGATE_MULTI_SHIFT_GRADIENT_H
31
33
35// Base classes for iterative processes based on operators
36// single input vec, single output vec.
38
39template<class Field>
41 public OperatorFunction<Field>
42{
43public:
44
45 using OperatorFunction<Field>::operator();
46
47 // RealD Tolerance;
49 Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion
50 std::vector<int> IterationsToCompleteShift; // Iterations for this shift
53 std::vector<RealD> TrueResidualShift;
54
56 MaxIterations(maxit),
57 shifts(_shifts)
58 {
59 verbose=1;
60 IterationsToCompleteShift.resize(_shifts.order);
61 TrueResidualShift.resize(_shifts.order);
62 }
63
64 void operator() (LinearOperatorBase<Field> &Linop, const Field &src, Field &psi)
65 {
66 GridBase *grid = src.Grid();
67 int nshift = shifts.order;
68 std::vector<Field> results(nshift,grid);
69 (*this)(Linop,src,results,psi);
70 }
71 void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector<Field> &results, Field &psi)
72 {
73 int nshift = shifts.order;
74
75 (*this)(Linop,src,results);
76
77 psi = shifts.norm*src;
78 for(int i=0;i<nshift;i++){
79 psi = psi + shifts.residues[i]*results[i];
80 }
81
82 return;
83 }
84
85 void operator() (LinearOperatorBase<Field> &Linop, const Field &src, std::vector<Field> &psi)
86 {
87 GRID_TRACE("ConjugateGradientMultiShift");
88
89 GridBase *grid = src.Grid();
90
92 // Convenience references to the info stored in "MultiShiftFunction"
94 int nshift = shifts.order;
95
96 std::vector<RealD> &mass(shifts.poles); // Make references to array in "shifts"
97 std::vector<RealD> &mresidual(shifts.tolerances);
98 std::vector<RealD> alpha(nshift,1.0);
99 std::vector<Field> ps(nshift,grid);// Search directions
100
101 assert(psi.size()==nshift);
102 assert(mass.size()==nshift);
103 assert(mresidual.size()==nshift);
104
105 // remove dynamic sized arrays on stack; 2d is a pain with vector
106 std::vector<RealD> bs(nshift);
107 std::vector<RealD> rsq(nshift);
108 std::vector<std::array<RealD,2> > z(nshift);
109 std::vector<int> converged(nshift);
110
111 const int primary =0;
112
113 //Primary shift fields CG iteration
114 RealD a,b,c,d;
115 RealD cp,bp,qq; //prev
116
117 // Matrix mult fields
118 Field r(grid);
119 Field p(grid);
120 Field tmp(grid);
121 Field mmp(grid);
122
123 // Check lightest mass
124 for(int s=0;s<nshift;s++){
125 assert( mass[s]>= mass[primary] );
126 converged[s]=0;
127 }
128
129 // Wire guess to zero
130 // Residuals "r" are src
131 // First search direction "p" is also src
132 cp = norm2(src);
133
134 // Handle trivial case of zero src.
135 if( cp == 0. ){
136 for(int s=0;s<nshift;s++){
137 psi[s] = Zero();
139 TrueResidualShift[s] = 0.;
140 }
141 return;
142 }
143
144 for(int s=0;s<nshift;s++){
145 rsq[s] = cp * mresidual[s] * mresidual[s];
146 std::cout<<GridLogMessage<<"ConjugateGradientMultiShift: shift "<<s
147 <<" target resid^2 "<<rsq[s]<<std::endl;
148 ps[s] = src;
149 }
150 // r and p for primary
151 r=src;
152 p=src;
153
154 //MdagM+m[0]
155 Linop.HermOpAndNorm(p,mmp,d,qq);
156 axpy(mmp,mass[0],p,mmp);
157 RealD rn = norm2(p);
158 d += rn*mass[0];
159
160 // have verified that inner product of
161 // p and mmp is equal to d after this since
162 // the d computation is tricky
163 // qq = real(innerProduct(p,mmp));
164 // std::cout<<GridLogMessage << "debug equal ? qq "<<qq<<" d "<< d<<std::endl;
165
166 b = -cp /d;
167
168 // Set up the various shift variables
169 int iz=0;
170 z[0][1-iz] = 1.0;
171 z[0][iz] = 1.0;
172 bs[0] = b;
173 for(int s=1;s<nshift;s++){
174 z[s][1-iz] = 1.0;
175 z[s][iz] = 1.0/( 1.0 - b*(mass[s]-mass[0]));
176 bs[s] = b*z[s][iz];
177 }
178
179 // r += b[0] A.p[0]
180 // c= norm(r)
181 c=axpy_norm(r,b,mmp,r);
182
183 for(int s=0;s<nshift;s++) {
184 axpby(psi[s],0.,-bs[s]*alpha[s],src,src);
185 }
186
187 std::cout << GridLogIterative << "ConjugateGradientMultiShift: initial rn (|src|^2) =" << rn << " qq (|MdagM src|^2) =" << qq << " d ( dot(src, [MdagM + m_0]src) ) =" << d << " c=" << c << std::endl;
188
189
191 // Timers
193 GridStopWatch AXPYTimer;
194 GridStopWatch ShiftTimer;
195 GridStopWatch QRTimer;
196 GridStopWatch MatrixTimer;
197 GridStopWatch SolverTimer;
198 SolverTimer.Start();
199
200 // Iteration loop
201 int k;
202
203 for (k=1;k<=MaxIterations;k++){
204
205 a = c /cp;
206 AXPYTimer.Start();
207 axpy(p,a,p,r);
208 AXPYTimer.Stop();
209
210 // Note to self - direction ps is iterated seperately
211 // for each shift. Does not appear to have any scope
212 // for avoiding linear algebra in "single" case.
213 //
214 // However SAME r is used. Could load "r" and update
215 // ALL ps[s]. 2/3 Bandwidth saving
216 // New Kernel: Load r, vector of coeffs, vector of pointers ps
217 AXPYTimer.Start();
218 for(int s=0;s<nshift;s++){
219 if ( ! converged[s] ) {
220 if (s==0){
221 axpy(ps[s],a,ps[s],r);
222 } else{
223 RealD as =a *z[s][iz]*bs[s] /(z[s][1-iz]*b);
224 axpby(ps[s],z[s][iz],as,r,ps[s]);
225 }
226 }
227 }
228 AXPYTimer.Stop();
229
230 cp=c;
231 MatrixTimer.Start();
232 //Linop.HermOpAndNorm(p,mmp,d,qq); // d is used
233 // The below is faster on KNL
234 Linop.HermOp(p,mmp);
235 d=real(innerProduct(p,mmp));
236
237 MatrixTimer.Stop();
238
239 AXPYTimer.Start();
240 axpy(mmp,mass[0],p,mmp);
241 AXPYTimer.Stop();
242 RealD rn = norm2(p);
243 d += rn*mass[0];
244
245 bp=b;
246 b=-cp/d;
247
248 AXPYTimer.Start();
249 c=axpy_norm(r,b,mmp,r);
250 AXPYTimer.Stop();
251
252 // Toggle the recurrence history
253 bs[0] = b;
254 iz = 1-iz;
255 ShiftTimer.Start();
256 for(int s=1;s<nshift;s++){
257 if((!converged[s])){
258 RealD z0 = z[s][1-iz];
259 RealD z1 = z[s][iz];
260 z[s][iz] = z0*z1*bp
261 / (b*a*(z1-z0) + z1*bp*(1- (mass[s]-mass[0])*b));
262 bs[s] = b*z[s][iz]/z0; // NB sign rel to Mike
263 }
264 }
265 ShiftTimer.Stop();
266
267 for(int s=0;s<nshift;s++){
268 int ss = s;
269 // Scope for optimisation here in case of "single".
270 // Could load psi[0] and pull all ps[s] in.
271 // if ( single ) ss=primary;
272 // Bandwith saving in single case is Ls * 3 -> 2+Ls, so ~ 3x saving
273 // Pipelined CG gain:
274 //
275 // New Kernel: Load r, vector of coeffs, vector of pointers ps
276 // New Kernel: Load psi[0], vector of coeffs, vector of pointers ps
277 // If can predict the coefficient bs then we can fuse these and avoid write reread cyce
278 // on ps[s].
279 // Before: 3 x npole + 3 x npole
280 // After : 2 x npole (ps[s]) => 3x speed up of multishift CG.
281
282 if( (!converged[s]) ) {
283 axpy(psi[ss],-bs[s]*alpha[s],ps[s],psi[ss]);
284 }
285 }
286
287 // Convergence checks
288 int all_converged = 1;
289 for(int s=0;s<nshift;s++){
290
291 if ( (!converged[s]) ){
293
294 RealD css = c * z[s][iz]* z[s][iz];
295
296 if(css<rsq[s]){
297 if ( ! converged[s] )
298 std::cout<<GridLogMessage<<"ConjugateGradientMultiShift k="<<k<<" Shift "<<s<<" has converged"<<std::endl;
299 converged[s]=1;
300 } else {
301 all_converged=0;
302 }
303
304 }
305 }
306
307 if ( all_converged ){
308
309 SolverTimer.Stop();
310
311
312 std::cout<<GridLogMessage<< "CGMultiShift: All shifts have converged iteration "<<k<<std::endl;
313 std::cout<<GridLogMessage<< "CGMultiShift: Checking solutions"<<std::endl;
314
315 // Check answers
316 for(int s=0; s < nshift; s++) {
317 Linop.HermOpAndNorm(psi[s],mmp,d,qq);
318 axpy(tmp,mass[s],psi[s],mmp);
319 axpy(r,-alpha[s],src,tmp);
320 RealD rn = norm2(r);
321 RealD cn = norm2(src);
322 TrueResidualShift[s] = std::sqrt(rn/cn);
323 std::cout<<GridLogMessage<<"CGMultiShift: shift["<<s<<"] true residual "<< TrueResidualShift[s] <<std::endl;
324 }
325
326 std::cout << GridLogMessage << "Time Breakdown "<<std::endl;
327 std::cout << GridLogMessage << "\tElapsed " << SolverTimer.Elapsed() <<std::endl;
328 std::cout << GridLogMessage << "\tAXPY " << AXPYTimer.Elapsed() <<std::endl;
329 std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
330 std::cout << GridLogMessage << "\tShift " << ShiftTimer.Elapsed() <<std::endl;
331
333
334 return;
335 }
336
337
338 }
339 // ugly hack
340 std::cout<<GridLogMessage<<"CG multi shift did not converge"<<std::endl;
341 // assert(0);
342 }
343
344};
346#endif
RealD axpy_norm(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpby(Lattice< vobj > &ret, sobj a, sobj b, const Lattice< vobj > &x, const Lattice< vobj > &y)
Lattice< vobj > real(const Lattice< vobj > &lhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
GridLogger GridLogIterative(1, "Iterative", GridLogColours, "BLUE")
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
double RealD
Definition Simd.h:61
#define GRID_TRACE(name)
Definition Tracing.h:68
void operator()(LinearOperatorBase< Field > &Linop, const Field &src, Field &psi)
ConjugateGradientMultiShift(Integer maxit, const MultiShiftFunction &_shifts)
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
virtual void HermOp(const Field &in, Field &out)=0
virtual void HermOpAndNorm(const Field &in, Field &out, RealD &n1, RealD &n2)=0
Definition Simd.h:194