Grid 0.7.0
ConjugateGradientMultiShiftMixedPrec.h
Go to the documentation of this file.
1/*************************************************************************************
2
3 Grid physics library, www.github.com/paboyle/Grid
4
5 Source file: ./lib/algorithms/iterative/ConjugateGradientMultiShift.h
6
7 Copyright (C) 2015
8
9Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
10Author: Peter Boyle <paboyle@ph.ed.ac.uk>
11Author: Christopher Kelly <ckelly@bnl.gov>
12
13 This program is free software; you can redistribute it and/or modify
14 it under the terms of the GNU General Public License as published by
15 the Free Software Foundation; either version 2 of the License, or
16 (at your option) any later version.
17
18 This program is distributed in the hope that it will be useful,
19 but WITHOUT ANY WARRANTY; without even the implied warranty of
20 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 GNU General Public License for more details.
22
23 You should have received a copy of the GNU General Public License along
24 with this program; if not, write to the Free Software Foundation, Inc.,
25 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
26
27 See the full license in the file "LICENSE" in the top level distribution directory
28*************************************************************************************/
29/* END LEGAL */
30#ifndef GRID_CONJUGATE_GRADIENT_MULTI_SHIFT_MIXEDPREC_H
31#define GRID_CONJUGATE_GRADIENT_MULTI_SHIFT_MIXEDPREC_H
32
34
//CK 2020: A variant of the multi-shift conjugate gradient with the matrix multiplication in single precision.
//The residual, the search directions and the solutions are stored in double precision; only the operator
//application is performed in single precision, and its result is converted back to double.
//Every ReliableUpdateFreq iterations the residual is replaced by the true residual, recomputed with the
//double-precision operator.

//For safety, a final (mixed-precision) CG is applied to clean up any shift that has not reached its tolerance.
40
41//Linop to add shift to input linop, used in cleanup CG
43template<typename Field>
44class ShiftedLinop: public LinearOperatorBase<Field>{
45public:
48
49 ShiftedLinop(LinearOperatorBase<Field> &_linop_base, RealD _shift): linop_base(_linop_base), shift(_shift){}
50
51 void OpDiag (const Field &in, Field &out){ assert(0); }
52 void OpDir (const Field &in, Field &out,int dir,int disp){ assert(0); }
53 void OpDirAll (const Field &in, std::vector<Field> &out){ assert(0); }
54
55 void Op (const Field &in, Field &out){ assert(0); }
56 void AdjOp (const Field &in, Field &out){ assert(0); }
57
58 void HermOp(const Field &in, Field &out){
59 linop_base.HermOp(in, out);
60 axpy(out, shift, in, out);
61 }
62
63 void HermOpAndNorm(const Field &in, Field &out,RealD &n1,RealD &n2){
64 HermOp(in,out);
65 ComplexD dot = innerProduct(in,out);
66 n1=real(dot);
67 n2=norm2(out);
68 }
69};
70};
71
72
73template<class FieldD, class FieldF,
74 typename std::enable_if< getPrecision<FieldD>::value == 2, int>::type = 0,
75 typename std::enable_if< getPrecision<FieldF>::value == 1, int>::type = 0>
77 public OperatorFunction<FieldD>
78{
79public:
80
81 using OperatorFunction<FieldD>::operator();
82
86 Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion
87 std::vector<int> IterationsToCompleteShift; // Iterations for this shift
90 std::vector<RealD> TrueResidualShift;
91
92 int ReliableUpdateFreq; //number of iterations between reliable updates
93
94 GridBase* SinglePrecGrid; //Grid for single-precision fields
96
98 GridBase* _SinglePrecGrid, LinearOperatorBase<FieldF> &_Linop_f,
99 int _ReliableUpdateFreq) :
100 MaxIterationsMshift(maxit), shifts(_shifts), SinglePrecGrid(_SinglePrecGrid), Linop_f(_Linop_f), ReliableUpdateFreq(_ReliableUpdateFreq),
101 MaxIterations(20000)
102 {
103 verbose=1;
104 IterationsToCompleteShift.resize(_shifts.order);
105 TrueResidualShift.resize(_shifts.order);
106 }
107
108 void operator() (LinearOperatorBase<FieldD> &Linop, const FieldD &src, FieldD &psi)
109 {
110 GridBase *grid = src.Grid();
111 int nshift = shifts.order;
112 std::vector<FieldD> results(nshift,grid);
113 (*this)(Linop,src,results,psi);
114 }
115 void operator() (LinearOperatorBase<FieldD> &Linop, const FieldD &src, std::vector<FieldD> &results, FieldD &psi)
116 {
117 int nshift = shifts.order;
118
119 (*this)(Linop,src,results);
120
121 psi = shifts.norm*src;
122 for(int i=0;i<nshift;i++){
123 psi = psi + shifts.residues[i]*results[i];
124 }
125
126 return;
127 }
128
// Core solver: for every pole s of `shifts`, solve (HermOp + mass[s]) psi_d[s] = src_d from a
// zero initial guess. A single multi-shift CG recurrence is driven by the primary
// (lightest) pole mass[0].
//
// Precision strategy, as implemented below:
//  - the residual r_d, the search directions p_d / ps_d and the solutions psi_d are double precision;
//  - each iteration applies the single-precision operator Linop_f to p_f (a single-precision
//    copy of p_d) and converts the result back to double (mmp_d);
//  - every ReliableUpdateFreq iterations r_d is replaced by the true residual of the primary
//    system, recomputed with the double-precision operator Linop_d.
// When the loop exits, each shift's true relative residual |r|/|src| is stored in
// TrueResidualShift[s]. Shifts that do not meet mresidual[s] are re-solved with a
// MixedPrecisionConjugateGradient cleanup.
//
// Linop_d : double-precision operator (HermOp and HermOpAndNorm are used).
// src_d   : right-hand side. If |src_d|^2 == 0, every psi_d[s] is set to zero and the
//           function returns immediately.
// psi_d   : output, one solution per pole; its size must equal shifts.order (asserted).
//           Existing contents are overwritten.
// Preconditions (asserted): mass[s] >= mass[0] for every s.
129 void operator() (LinearOperatorBase<FieldD> &Linop_d, const FieldD &src_d, std::vector<FieldD> &psi_d)
130 {
131 GRID_TRACE("ConjugateGradientMultiShiftMixedPrec");
132 GridBase *DoublePrecGrid = src_d.Grid();
133
// Workspaces for converting fields between the double- and single-precision grids.
134 precisionChangeWorkspace pc_wk_s_to_d(DoublePrecGrid,SinglePrecGrid);
135 precisionChangeWorkspace pc_wk_d_to_s(SinglePrecGrid,DoublePrecGrid);
136
138 // Convenience references to the info stored in "MultiShiftFunction"
140 int nshift = shifts.order;
141
142 std::vector<RealD> &mass(shifts.poles); // Make references to array in "shifts"
143 std::vector<RealD> &mresidual(shifts.tolerances);
// alpha[s] is 1.0 for every shift and is not modified anywhere in this function.
144 std::vector<RealD> alpha(nshift,1.0);
145
146 //Double precision search directions
147 FieldD p_d(DoublePrecGrid);
148 std::vector<FieldD> ps_d(nshift, DoublePrecGrid);// Search directions (double precision)
149
150 FieldD tmp_d(DoublePrecGrid);
151 FieldD r_d(DoublePrecGrid);
152 FieldD mmp_d(DoublePrecGrid);
153
154 assert(psi_d.size()==nshift);
155 assert(mass.size()==nshift);
156 assert(mresidual.size()==nshift);
157
158 // Per-shift scalars (heap-allocated std::vectors): step lengths, target |r|^2, z history, convergence flags
159 std::vector<RealD> bs(nshift);
160 std::vector<RealD> rsq(nshift);
161 std::vector<RealD> rsqf(nshift);
162 std::vector<std::array<RealD,2> > z(nshift);
163 std::vector<int> converged(nshift);
164
165 const int primary =0;
166
167 //Primary shift fields CG iteration
168 RealD a,b,c,d;
169 RealD cp,bp,qq; //prev
170
171 // Matrix mult fields
172 FieldF p_f(SinglePrecGrid);
173 FieldF mmp_f(SinglePrecGrid);
174
175 // Check lightest mass
176 for(int s=0;s<nshift;s++){
177 assert( mass[s]>= mass[primary] );
178 converged[s]=0;
179 }
180
181 // Initial guess is wired to zero, so:
182 // Residuals "r" are src
183 // First search direction "p" is also src
184 cp = norm2(src_d);
185
186 // Handle trivial case of zero src.
187 if( cp == 0. ){
188 for(int s=0;s<nshift;s++){
189 psi_d[s] = Zero();
191 TrueResidualShift[s] = 0.;
192 }
193 return;
194 }
195
// Target |r|^2 per shift: |src|^2 * mresidual[s]^2 (i.e. a relative tolerance on |r|/|src|).
196 for(int s=0;s<nshift;s++){
197 rsq[s] = cp * mresidual[s] * mresidual[s];
198 rsqf[s] =rsq[s];
199 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec: shift "<< s <<" target resid "<<rsq[s]<<std::endl;
200 ps_d[s] = src_d;
201 }
202 // r and p for primary
203 p_d = src_d; //primary copy --- make this a reference to ps_d to save axpys
204 r_d = p_d;
205
206 //MdagM+m[0]
207 precisionChange(p_f, p_d, pc_wk_d_to_s);
208
// Operator sanity check on the first search direction: apply both the single- and the
// double-precision operator and assert |diff|^2 < 1.0.
// NOTE(review): this is an absolute threshold, not scaled by |mmp_d|^2.
// The d and qq written by the Linop_f call are overwritten by the Linop_d call, so the
// first step below uses the double-precision values.
209 Linop_f.HermOpAndNorm(p_f,mmp_f,d,qq); // mmp = MdagM p d=real(dot(p, mmp)), qq=norm2(mmp)
210 precisionChange(tmp_d, mmp_f, pc_wk_s_to_d);
211 Linop_d.HermOpAndNorm(p_d,mmp_d,d,qq); // mmp = MdagM p d=real(dot(p, mmp)), qq=norm2(mmp)
212 tmp_d = tmp_d - mmp_d;
213 std::cout << " Testing operators match "<<norm2(mmp_d)<<" f "<<norm2(mmp_f)<<" diff "<< norm2(tmp_d)<<std::endl;
214 assert(norm2(tmp_d)< 1.0);
215
// Add the primary shift: mmp_d += mass[0]*p_d and d += mass[0]*|p_d|^2,
// so that d = Re<p, (HermOp + mass[0]) p>.
216 axpy(mmp_d,mass[0],p_d,mmp_d);
217 RealD rn = norm2(p_d);
218 d += rn*mass[0];
219
220 b = -cp /d;
221
222 // Set up the various shift variables
// z[s] keeps two generations of per-shift scale factors; iz selects the current slot.
// bs[s] is the step length for shift s (bs[0] == b for the primary).
223 int iz=0;
224 z[0][1-iz] = 1.0;
225 z[0][iz] = 1.0;
226 bs[0] = b;
227 for(int s=1;s<nshift;s++){
228 z[s][1-iz] = 1.0;
229 z[s][iz] = 1.0/( 1.0 - b*(mass[s]-mass[0]));
230 bs[s] = b*z[s][iz];
231 }
232
233 // r += b[0] A.p[0]
234 // c= norm(r)
235 c=axpy_norm(r_d,b,mmp_d,r_d);
236
// First solution update from the zero guess: psi_d[s] = -bs[s]*alpha[s]*src_d (p_0 == src_d).
237 for(int s=0;s<nshift;s++) {
238 axpby(psi_d[s],0.,-bs[s]*alpha[s],src_d,src_d);
239 }
240
242 // Timers
// NOTE(review): QRTimer is declared but never started or reported.
244 GridStopWatch AXPYTimer, ShiftTimer, QRTimer, MatrixTimer, SolverTimer, PrecChangeTimer, CleanupTimer;
245
246 SolverTimer.Start();
247
248 // Iteration loop
249 int k;
250
251 for (k=1;k<=MaxIterationsMshift;k++){
252
// Primary CG: a = |r_k|^2 / |r_{k-1}|^2, then p_d = a*p_d + r_d.
253 a = c /cp;
254 AXPYTimer.Start();
255 axpy(p_d,a,p_d,r_d);
256
// Shifted search directions; a shift's direction is frozen once it has converged.
// For s == 0 this repeats the p_d update on ps_d[0].
257 for(int s=0;s<nshift;s++){
258 if ( ! converged[s] ) {
259 if (s==0){
260 axpy(ps_d[s],a,ps_d[s],r_d);
261 } else{
262 RealD as =a *z[s][iz]*bs[s] /(z[s][1-iz]*b);
263 axpby(ps_d[s],z[s][iz],as,r_d,ps_d[s]);
264 }
265 }
266 }
267 AXPYTimer.Stop();
268
269 PrecChangeTimer.Start();
270 precisionChange(p_f, p_d, pc_wk_d_to_s); //get back single prec search direction for linop
271 PrecChangeTimer.Stop();
272
// cp now holds the current |r|^2; c is recomputed after the residual update below.
273 cp=c;
274 MatrixTimer.Start();
275 Linop_f.HermOp(p_f,mmp_f);
276 MatrixTimer.Stop();
277
278 PrecChangeTimer.Start();
279 precisionChange(mmp_d, mmp_f, pc_wk_s_to_d); // From Float to Double
280 PrecChangeTimer.Stop();
281
// d = Re<p, (HermOp + mass[0]) p>, with HermOp applied in single precision; b = -cp/d.
282 AXPYTimer.Start();
283 d=real(innerProduct(p_d,mmp_d));
284 axpy(mmp_d,mass[0],p_d,mmp_d);
285 AXPYTimer.Stop();
286 RealD rn = norm2(p_d);
287 d += rn*mass[0];
288
289 bp=b;
290 b=-cp/d;
291
292 // Toggle the recurrence history
// After the toggle, z[s][1-iz] is the previous generation and z[s][iz] is overwritten with the new one.
293 bs[0] = b;
294 iz = 1-iz;
295 ShiftTimer.Start();
296 for(int s=1;s<nshift;s++){
297 if((!converged[s])){
298 RealD z0 = z[s][1-iz];
299 RealD z1 = z[s][iz];
300 z[s][iz] = z0*z1*bp
301 / (b*a*(z1-z0) + z1*bp*(1- (mass[s]-mass[0])*b));
302 bs[s] = b*z[s][iz]/z0; // NB sign rel to Mike
303 }
304 }
305 ShiftTimer.Stop();
306
307 //Update double precision solutions
// psi_d[s] += -bs[s]*alpha[s]*ps_d[s] for every shift that has not yet converged.
308 AXPYTimer.Start();
309 for(int s=0;s<nshift;s++){
310 int ss = s;
311 if( (!converged[s]) ) {
312 axpy(psi_d[ss],-bs[s]*alpha[s],ps_d[s],psi_d[ss]);
313 }
314 }
315
316 //Update the residual from the single-precision mmp; every ReliableUpdateFreq iterations it is then replaced below
317 c = axpy_norm(r_d,b,mmp_d,r_d);
318
319 AXPYTimer.Stop();
320
// Reliable update: r_d = src_d - (HermOp + mass[0]) psi_d[0], computed in double precision.
// NOTE(review): integer modulo by zero if ReliableUpdateFreq == 0; nothing in this function guards it.
321 if(k % ReliableUpdateFreq == 0){
322 RealD c_old = c;
323 //Replace r with true residual
324 MatrixTimer.Start();
325 Linop_d.HermOp(psi_d[0],mmp_d);
326 MatrixTimer.Stop();
327
328 AXPYTimer.Start();
329 axpy(mmp_d,mass[0],psi_d[0],mmp_d);
330
331 c = axpy_norm(r_d, -1.0, mmp_d, src_d);
332 AXPYTimer.Stop();
333
334 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec k="<<k<< ", replaced |r|^2 = "<<c_old <<" with |r|^2 = "<<c<<std::endl;
335 }
336
337 // Convergence checks
// The estimated |r_s|^2 for shift s is c * z[s][iz]^2, compared with rsqf[s] (set equal to rsq[s] above).
338 int all_converged = 1;
339 for(int s=0;s<nshift;s++){
340
341 if ( (!converged[s]) ){
343
344 RealD css = c * z[s][iz]* z[s][iz];
345
// The inner "if ( ! converged[s] )" below is always true here (same test as the enclosing if).
346 if(css<rsqf[s]){
347 if ( ! converged[s] )
348 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec k="<<k<<" Shift "<<s<<" has converged"<<std::endl;
349 converged[s]=1;
350 } else {
351 all_converged=0;
352 }
353
354 }
355 }
356
// NOTE(review): the loop runs k = 1..MaxIterationsMshift, but this exit fires at
// k == MaxIterationsMshift-1. The last permitted iteration therefore never runs, and the
// "did not converge" assert after the loop is reachable only when MaxIterationsMshift < 2.
// Confirm whether k == MaxIterationsMshift was intended.
357 if ( all_converged || k == MaxIterationsMshift-1){
358
359 SolverTimer.Stop();
360
361 if ( all_converged ){
362 std::cout<<GridLogMessage<< "ConjugateGradientMultiShiftMixedPrec: All shifts have converged iteration "<<k<<std::endl;
363 std::cout<<GridLogMessage<< "ConjugateGradientMultiShiftMixedPrec: Checking solutions"<<std::endl;
364 } else {
365 std::cout<<GridLogMessage<< "ConjugateGradientMultiShiftMixedPrec: Not all shifts have converged iteration "<<k<<std::endl;
366 }
367
368 // Check answers
// r_d = (HermOp + mass[s]) psi_d[s] - alpha[s]*src_d, computed with the double-precision operator;
// TrueResidualShift[s] = |r_d| / |src_d|.
369 for(int s=0; s < nshift; s++) {
370 Linop_d.HermOpAndNorm(psi_d[s],mmp_d,d,qq);
371 axpy(tmp_d,mass[s],psi_d[s],mmp_d);
372 axpy(r_d,-alpha[s],src_d,tmp_d);
373 RealD rn = norm2(r_d);
374 RealD cn = norm2(src_d);
375 TrueResidualShift[s] = std::sqrt(rn/cn);
376 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec: shift["<<s<<"] true residual "<< TrueResidualShift[s] << " target " << mresidual[s] << std::endl;
377
378 //If we have not reached the desired tolerance, do a (mixed precision) CG cleanup
// rsq[s] == |src_d|^2 * mresidual[s]^2, so this tests |r|/|src| >= mresidual[s].
379 if(rn >= rsq[s]){
380 CleanupTimer.Start();
381 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrec: performing cleanup step for shift " << s << std::endl;
382
383 //Setup linear operators for final cleanup
// NOTE(review): Linop_shift_f / Linop_shift_d are not declared anywhere in this listing; presumably
// ShiftedLinop<FieldF>(Linop_f, mass[s]) and ShiftedLinop<FieldD>(Linop_d, mass[s]) — confirm.
// The cleanup CG is constructed with mresidual[s] and MaxIterations (passed twice); it then
// overwrites psi_d[s] in place, starting from the multi-shift result.
386
387 MixedPrecisionConjugateGradient<FieldD,FieldF> cg(mresidual[s], MaxIterations, MaxIterations, SinglePrecGrid, Linop_shift_f, Linop_shift_d);
388 cg(src_d, psi_d[s]);
389
391 CleanupTimer.Stop();
392 }
393 }
394
395 std::cout << GridLogMessage << "ConjugateGradientMultiShiftMixedPrec: Time Breakdown for body"<<std::endl;
396 std::cout << GridLogMessage << "\tSolver " << SolverTimer.Elapsed() <<std::endl;
397 std::cout << GridLogMessage << "\t\tAXPY " << AXPYTimer.Elapsed() <<std::endl;
398 std::cout << GridLogMessage << "\t\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
399 std::cout << GridLogMessage << "\t\tShift " << ShiftTimer.Elapsed() <<std::endl;
400 std::cout << GridLogMessage << "\t\tPrecision Change " << PrecChangeTimer.Elapsed() <<std::endl;
401 std::cout << GridLogMessage << "\tFinal Cleanup " << CleanupTimer.Elapsed() <<std::endl;
402 std::cout << GridLogMessage << "\tSolver+Cleanup " << SolverTimer.Elapsed() + CleanupTimer.Elapsed() << std::endl;
403
405
406 return;
407 }
408
409 }
// Reached only if the loop finishes without taking the exit above (see the NOTE on the exit condition).
410 std::cout<<GridLogMessage<<"CG multi shift did not converge"<<std::endl;
411 assert(0);
412 }
413
414};
416#endif
RealD axpy_norm(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpby(Lattice< vobj > &ret, sobj a, sobj b, const Lattice< vobj > &x, const Lattice< vobj > &y)
Lattice< vobj > real(const Lattice< vobj > &lhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
void precisionChange(Lattice< VobjOut > &out, const Lattice< VobjIn > &in, const precisionChangeWorkspace &workspace)
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
#define GRID_TRACE(name)
Definition Tracing.h:68
ShiftedLinop(LinearOperatorBase< Field > &_linop_base, RealD _shift)
void HermOpAndNorm(const Field &in, Field &out, RealD &n1, RealD &n2)
ConjugateGradientMultiShiftMixedPrec(Integer maxit, const MultiShiftFunction &_shifts, GridBase *_SinglePrecGrid, LinearOperatorBase< FieldF > &_Linop_f, int _ReliableUpdateFreq)
void operator()(LinearOperatorBase< FieldD > &Linop, const FieldD &src, FieldD &psi)
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
virtual void HermOp(const Field &in, Field &out)=0
virtual void HermOpAndNorm(const Field &in, Field &out, RealD &n1, RealD &n2)=0
Definition Simd.h:194