Grid 0.7.0
ConjugateGradientMultiShiftCleanup.h
Go to the documentation of this file.
1/*************************************************************************************
2
3 Grid physics library, www.github.com/paboyle/Grid
4
5 Source file: ./lib/algorithms/iterative/ConjugateGradientMultiShift.h
6
7 Copyright (C) 2015
8
9Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
10Author: Peter Boyle <paboyle@ph.ed.ac.uk>
11Author: Christopher Kelly <ckelly@bnl.gov>
12
13 This program is free software; you can redistribute it and/or modify
14 it under the terms of the GNU General Public License as published by
15 the Free Software Foundation; either version 2 of the License, or
16 (at your option) any later version.
17
18 This program is distributed in the hope that it will be useful,
19 but WITHOUT ANY WARRANTY; without even the implied warranty of
20 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 GNU General Public License for more details.
22
23 You should have received a copy of the GNU General Public License along
24 with this program; if not, write to the Free Software Foundation, Inc.,
25 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
26
27 See the full license in the file "LICENSE" in the top level distribution directory
28*************************************************************************************/
29/* END LEGAL */
30#pragma once
31
33
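//CK 2020: A variant of the multi-shift conjugate gradient with the matrix multiplication in single precision.
//The residual is stored in single precision, but the search directions and solution are stored in double precision.
//Every update_freq iterations the residual is corrected in double precision.
//For safety, a final regular CG is applied to clean up if necessary.
//NOTE: the two preceding lines do not match the solver implemented by this class. Here the primary search
//direction and the residual are held in double precision (with a single-precision copy of the residual),
//while the per-shift search directions and solutions are held in single precision. Any shift that misses
//its tolerance after the multi-shift iteration is finished with a mixed-precision CG.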
38
39//PB Pure single, then double fixup
40
41template<class FieldD, class FieldF,
42 typename std::enable_if< getPrecision<FieldD>::value == 2, int>::type = 0,
43 typename std::enable_if< getPrecision<FieldF>::value == 1, int>::type = 0>
45 public OperatorFunction<FieldD>
46{
47public:
48
49 using OperatorFunction<FieldD>::operator();
50
54 Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion
55 std::vector<int> IterationsToCompleteShift; // Iterations for this shift
58 std::vector<RealD> TrueResidualShift;
59
60 int ReliableUpdateFreq; //number of iterations between reliable updates
61
62 GridBase* SinglePrecGrid; //Grid for single-precision fields
64
66 GridBase* _SinglePrecGrid, LinearOperatorBase<FieldF> &_Linop_f,
67 int _ReliableUpdateFreq) :
68 MaxIterationsMshift(maxit), shifts(_shifts), SinglePrecGrid(_SinglePrecGrid), Linop_f(_Linop_f), ReliableUpdateFreq(_ReliableUpdateFreq),
69 MaxIterations(20000)
70 {
71 verbose=1;
72 IterationsToCompleteShift.resize(_shifts.order);
73 TrueResidualShift.resize(_shifts.order);
74 }
75
76 void operator() (LinearOperatorBase<FieldD> &Linop, const FieldD &src, FieldD &psi)
77 {
78 GridBase *grid = src.Grid();
79 int nshift = shifts.order;
80 std::vector<FieldD> results(nshift,grid);
81 (*this)(Linop,src,results,psi);
82 }
83 void operator() (LinearOperatorBase<FieldD> &Linop, const FieldD &src, std::vector<FieldD> &results, FieldD &psi)
84 {
85 int nshift = shifts.order;
86
87 (*this)(Linop,src,results);
88
89 psi = shifts.norm*src;
90 for(int i=0;i<nshift;i++){
91 psi = psi + shifts.residues[i]*results[i];
92 }
93
94 return;
95 }
96
97 void operator() (LinearOperatorBase<FieldD> &Linop_d, const FieldD &src_d, std::vector<FieldD> &psi_d)
98 {
99 GRID_TRACE("ConjugateGradientMultiShiftMixedPrecCleanup");
100 GridBase *DoublePrecGrid = src_d.Grid();
101
103 // Convenience references to the info stored in "MultiShiftFunction"
105 int nshift = shifts.order;
106
107 std::vector<RealD> &mass(shifts.poles); // Make references to array in "shifts"
108 std::vector<RealD> &mresidual(shifts.tolerances);
109 std::vector<RealD> alpha(nshift,1.0);
110
111 //Double precision search directions
112 FieldD p_d(DoublePrecGrid);
113 std::vector<FieldF> ps_f (nshift, SinglePrecGrid);// Search directions (single precision)
114 std::vector<FieldF> psi_f(nshift, SinglePrecGrid);// solutions (single precision)
115
116 FieldD tmp_d(DoublePrecGrid);
117 FieldD r_d(DoublePrecGrid);
118 FieldF r_f(SinglePrecGrid);
119 FieldD mmp_d(DoublePrecGrid);
120
121 assert(psi_d.size()==nshift);
122 assert(mass.size()==nshift);
123 assert(mresidual.size()==nshift);
124
125 // dynamic sized arrays on stack; 2d is a pain with vector
126 std::vector<RealD> bs(nshift);
127 std::vector<RealD> rsq(nshift);
128 std::vector<RealD> rsqf(nshift);
129 std::vector<std::array<RealD,2> > z(nshift);
130 std::vector<int> converged(nshift);
131
132 const int primary =0;
133
134 //Primary shift fields CG iteration
135 RealD a,b,c,d;
136 RealD cp,bp,qq; //prev
137
138 // Matrix mult fields
139 FieldF p_f(SinglePrecGrid);
140 FieldF mmp_f(SinglePrecGrid);
141
142 // Check lightest mass
143 for(int s=0;s<nshift;s++){
144 assert( mass[s]>= mass[primary] );
145 converged[s]=0;
146 }
147
148 // Wire guess to zero
149 // Residuals "r" are src
150 // First search direction "p" is also src
151 cp = norm2(src_d);
152
153 // Handle trivial case of zero src.
154 if( cp == 0. ){
155 for(int s=0;s<nshift;s++){
156 psi_d[s] = Zero();
157 psi_f[s] = Zero();
159 TrueResidualShift[s] = 0.;
160 }
161 return;
162 }
163
164 for(int s=0;s<nshift;s++){
165 rsq[s] = cp * mresidual[s] * mresidual[s];
166 rsqf[s] =rsq[s];
167 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrecCleanup: shift "<< s <<" target resid "<<rsq[s]<<std::endl;
168 // ps_d[s] = src_d;
169 precisionChange(ps_f[s],src_d);
170 }
171 // r and p for primary
172 p_d = src_d; //primary copy --- make this a reference to ps_d to save axpys
173 r_d = p_d;
174
175 //MdagM+m[0]
176 precisionChange(p_f,p_d);
177 Linop_f.HermOpAndNorm(p_f,mmp_f,d,qq); // mmp = MdagM p d=real(dot(p, mmp)), qq=norm2(mmp)
178 precisionChange(tmp_d,mmp_f);
179 Linop_d.HermOpAndNorm(p_d,mmp_d,d,qq); // mmp = MdagM p d=real(dot(p, mmp)), qq=norm2(mmp)
180 tmp_d = tmp_d - mmp_d;
181 std::cout << " Testing operators match "<<norm2(mmp_d)<<" f "<<norm2(mmp_f)<<" diff "<< norm2(tmp_d)<<std::endl;
182 // assert(norm2(tmp_d)< 1.0e-4);
183
184 axpy(mmp_d,mass[0],p_d,mmp_d);
185 RealD rn = norm2(p_d);
186 d += rn*mass[0];
187
188 b = -cp /d;
189
190 // Set up the various shift variables
191 int iz=0;
192 z[0][1-iz] = 1.0;
193 z[0][iz] = 1.0;
194 bs[0] = b;
195 for(int s=1;s<nshift;s++){
196 z[s][1-iz] = 1.0;
197 z[s][iz] = 1.0/( 1.0 - b*(mass[s]-mass[0]));
198 bs[s] = b*z[s][iz];
199 }
200
201 // r += b[0] A.p[0]
202 // c= norm(r)
203 c=axpy_norm(r_d,b,mmp_d,r_d);
204
205 for(int s=0;s<nshift;s++) {
206 axpby(psi_d[s],0.,-bs[s]*alpha[s],src_d,src_d);
207 precisionChange(psi_f[s],psi_d[s]);
208 }
209
211 // Timers
213 GridStopWatch AXPYTimer, ShiftTimer, QRTimer, MatrixTimer, SolverTimer, PrecChangeTimer, CleanupTimer;
214
215 SolverTimer.Start();
216
217 // Iteration loop
218 int k;
219
220 for (k=1;k<=MaxIterationsMshift;k++){
221
222 a = c /cp;
223 AXPYTimer.Start();
224 axpy(p_d,a,p_d,r_d);
225 AXPYTimer.Stop();
226
227 PrecChangeTimer.Start();
228 precisionChange(r_f, r_d);
229 PrecChangeTimer.Stop();
230
231 AXPYTimer.Start();
232 for(int s=0;s<nshift;s++){
233 if ( ! converged[s] ) {
234 if (s==0){
235 axpy(ps_f[s],a,ps_f[s],r_f);
236 } else{
237 RealD as =a *z[s][iz]*bs[s] /(z[s][1-iz]*b);
238 axpby(ps_f[s],z[s][iz],as,r_f,ps_f[s]);
239 }
240 }
241 }
242 AXPYTimer.Stop();
243
244 cp=c;
245 PrecChangeTimer.Start();
246 precisionChange(p_f, p_d); //get back single prec search direction for linop
247 PrecChangeTimer.Stop();
248 MatrixTimer.Start();
249 Linop_f.HermOp(p_f,mmp_f);
250 MatrixTimer.Stop();
251 PrecChangeTimer.Start();
252 precisionChange(mmp_d, mmp_f); // From Float to Double
253 PrecChangeTimer.Stop();
254
255 d=real(innerProduct(p_d,mmp_d));
256 axpy(mmp_d,mass[0],p_d,mmp_d);
257 RealD rn = norm2(p_d);
258 d += rn*mass[0];
259
260 bp=b;
261 b=-cp/d;
262
263 // Toggle the recurrence history
264 bs[0] = b;
265 iz = 1-iz;
266 ShiftTimer.Start();
267 for(int s=1;s<nshift;s++){
268 if((!converged[s])){
269 RealD z0 = z[s][1-iz];
270 RealD z1 = z[s][iz];
271 z[s][iz] = z0*z1*bp
272 / (b*a*(z1-z0) + z1*bp*(1- (mass[s]-mass[0])*b));
273 bs[s] = b*z[s][iz]/z0; // NB sign rel to Mike
274 }
275 }
276 ShiftTimer.Stop();
277
278 //Update single precision solutions
279 AXPYTimer.Start();
280 for(int s=0;s<nshift;s++){
281 int ss = s;
282 if( (!converged[s]) ) {
283 axpy(psi_f[ss],-bs[s]*alpha[s],ps_f[s],psi_f[ss]);
284 }
285 }
286 c = axpy_norm(r_d,b,mmp_d,r_d);
287 AXPYTimer.Stop();
288
289 // Convergence checks
290 int all_converged = 1;
291 for(int s=0;s<nshift;s++){
292
293 if ( (!converged[s]) ){
295
296 RealD css = c * z[s][iz]* z[s][iz];
297
298 if(css<rsqf[s]){
299 if ( ! converged[s] )
300 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrecCleanup k="<<k<<" Shift "<<s<<" has converged"<<std::endl;
301 converged[s]=1;
302 } else {
303 all_converged=0;
304 }
305
306 }
307 }
308
309 if ( all_converged || k == MaxIterationsMshift-1){
310
311 SolverTimer.Stop();
312
313 for(int s=0;s<nshift;s++){
314 precisionChange(psi_d[s],psi_f[s]);
315 }
316
317
318 if ( all_converged ){
319 std::cout<<GridLogMessage<< "ConjugateGradientMultiShiftMixedPrecCleanup: All shifts have converged iteration "<<k<<std::endl;
320 std::cout<<GridLogMessage<< "ConjugateGradientMultiShiftMixedPrecCleanup: Checking solutions"<<std::endl;
321 } else {
322 std::cout<<GridLogMessage<< "ConjugateGradientMultiShiftMixedPrecCleanup: Not all shifts have converged iteration "<<k<<std::endl;
323 }
324
325 // Check answers
326 for(int s=0; s < nshift; s++) {
327 Linop_d.HermOpAndNorm(psi_d[s],mmp_d,d,qq);
328 axpy(tmp_d,mass[s],psi_d[s],mmp_d);
329 axpy(r_d,-alpha[s],src_d,tmp_d);
330 RealD rn = norm2(r_d);
331 RealD cn = norm2(src_d);
332 TrueResidualShift[s] = std::sqrt(rn/cn);
333 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrecCleanup: shift["<<s<<"] true residual "<< TrueResidualShift[s] << " target " << mresidual[s] << std::endl;
334
335 //If we have not reached the desired tolerance, do a (mixed precision) CG cleanup
336 if(rn >= rsq[s]){
337 CleanupTimer.Start();
338 std::cout<<GridLogMessage<<"ConjugateGradientMultiShiftMixedPrecCleanup: performing cleanup step for shift " << s << std::endl;
339
340 //Setup linear operators for final cleanup
343
344 MixedPrecisionConjugateGradient<FieldD,FieldF> cg(mresidual[s], MaxIterations, MaxIterations, SinglePrecGrid, Linop_shift_f, Linop_shift_d);
345 cg(src_d, psi_d[s]);
346
348 CleanupTimer.Stop();
349 }
350 }
351
352 std::cout << GridLogMessage << "ConjugateGradientMultiShiftMixedPrecCleanup: Time Breakdown for body"<<std::endl;
353 std::cout << GridLogMessage << "\tSolver " << SolverTimer.Elapsed() <<std::endl;
354 std::cout << GridLogMessage << "\t\tAXPY " << AXPYTimer.Elapsed() <<std::endl;
355 std::cout << GridLogMessage << "\t\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
356 std::cout << GridLogMessage << "\t\tShift " << ShiftTimer.Elapsed() <<std::endl;
357 std::cout << GridLogMessage << "\t\tPrecision Change " << PrecChangeTimer.Elapsed() <<std::endl;
358 std::cout << GridLogMessage << "\tFinal Cleanup " << CleanupTimer.Elapsed() <<std::endl;
359 std::cout << GridLogMessage << "\tSolver+Cleanup " << SolverTimer.Elapsed() + CleanupTimer.Elapsed() << std::endl;
360
362
363 return;
364 }
365
366 }
367 std::cout<<GridLogMessage<<"CG multi shift did not converge"<<std::endl;
368 assert(0);
369 }
370
371};
373
RealD axpy_norm(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpby(Lattice< vobj > &ret, sobj a, sobj b, const Lattice< vobj > &x, const Lattice< vobj > &y)
Lattice< vobj > real(const Lattice< vobj > &lhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
void precisionChange(Lattice< VobjOut > &out, const Lattice< VobjIn > &in, const precisionChangeWorkspace &workspace)
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
double RealD
Definition Simd.h:61
#define GRID_TRACE(name)
Definition Tracing.h:68
void operator()(LinearOperatorBase< FieldD > &Linop, const FieldD &src, FieldD &psi)
ConjugateGradientMultiShiftMixedPrecCleanup(Integer maxit, const MultiShiftFunction &_shifts, GridBase *_SinglePrecGrid, LinearOperatorBase< FieldF > &_Linop_f, int _ReliableUpdateFreq)
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
virtual void HermOpAndNorm(const Field &in, Field &out, RealD &n1, RealD &n2)=0
Definition Simd.h:194