Grid 0.7.0
ConjugateGradient.h
Go to the documentation of this file.
1/*************************************************************************************
2
3Grid physics library, www.github.com/paboyle/Grid
4
5Source file: ./lib/algorithms/iterative/ConjugateGradient.h
6
7Copyright (C) 2015
8
9Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
10Author: Peter Boyle <paboyle@ph.ed.ac.uk>
11Author: paboyle <paboyle@ph.ed.ac.uk>
12
13This program is free software; you can redistribute it and/or modify
14it under the terms of the GNU General Public License as published by
15the Free Software Foundation; either version 2 of the License, or
16(at your option) any later version.
17
18This program is distributed in the hope that it will be useful,
19but WITHOUT ANY WARRANTY; without even the implied warranty of
20MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21GNU General Public License for more details.
22
23You should have received a copy of the GNU General Public License along
24with this program; if not, write to the Free Software Foundation, Inc.,
2551 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
26
27See the full license in the file "LICENSE" in the top level distribution
28directory
29*************************************************************************************/
30 /* END LEGAL */
31#ifndef GRID_CONJUGATE_GRADIENT_H
32#define GRID_CONJUGATE_GRADIENT_H
33
35
37// Base classes for iterative processes based on operators
38// single input vec, single output vec.
40
41
42template <class Field>
43class ConjugateGradient : public OperatorFunction<Field> {
44public:
45
46 using OperatorFunction<Field>::operator();
47
48 bool ErrorOnNoConverge; // throw an assert when the CG fails to converge.
49 // Defaults true.
52 Integer IterationsToComplete; //Number of iterations the CG took to finish. Filled in upon completion
54
55 ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv = true)
56 : Tolerance(tol),
57 MaxIterations(maxit),
58 ErrorOnNoConverge(err_on_no_conv)
59 {};
60
61 virtual void LogIteration(int k,RealD a,RealD b){
62 // std::cout << "ConjugageGradient::LogIteration() "<<std::endl;
63 };
64 virtual void LogBegin(void){
65 std::cout << "ConjugageGradient::LogBegin() "<<std::endl;
66 };
67
68 void operator()(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi) {
69
70 this->LogBegin();
71
72 GRID_TRACE("ConjugateGradient");
73 GridStopWatch PreambleTimer;
74 GridStopWatch ConstructTimer;
75 GridStopWatch NormTimer;
76 GridStopWatch AssignTimer;
77 PreambleTimer.Start();
78 psi.Checkerboard() = src.Checkerboard();
79
80 conformable(psi, src);
81
82 RealD cp, c, a, d, b, ssq, qq;
83 //RealD b_pred;
84
85 // Was doing copies
86 ConstructTimer.Start();
87 Field p (src.Grid());
88 Field mmp(src.Grid());
89 Field r (src.Grid());
90 ConstructTimer.Stop();
91
92 // Initial residual computation & set up
93 NormTimer.Start();
94 ssq = norm2(src);
95 RealD guess = norm2(psi);
96 NormTimer.Stop();
97 assert(std::isnan(guess) == 0);
98 AssignTimer.Start();
99 if ( guess == 0.0 ) {
100 r = src;
101 p = r;
102 a = ssq;
103 } else {
104 Linop.HermOpAndNorm(psi, mmp, d, b);
105 r = src - mmp;
106 p = r;
107 a = norm2(p);
108 }
109 cp = a;
110 AssignTimer.Stop();
111
112 // Handle trivial case of zero src
113 if (ssq == 0.){
114 psi = Zero();
116 TrueResidual = 0.;
117 return;
118 }
119
120 std::cout << GridLogIterative << std::setprecision(8) << "ConjugateGradient: guess " << guess << std::endl;
121 std::cout << GridLogIterative << std::setprecision(8) << "ConjugateGradient: src " << ssq << std::endl;
122 std::cout << GridLogIterative << std::setprecision(8) << "ConjugateGradient: mp " << d << std::endl;
123 std::cout << GridLogIterative << std::setprecision(8) << "ConjugateGradient: mmp " << b << std::endl;
124 std::cout << GridLogIterative << std::setprecision(8) << "ConjugateGradient: cp,r " << cp << std::endl;
125 std::cout << GridLogIterative << std::setprecision(8) << "ConjugateGradient: p " << a << std::endl;
126
127 RealD rsq = Tolerance * Tolerance * ssq;
128
129 // Check if guess is really REALLY good :)
130 if (cp <= rsq) {
131 TrueResidual = std::sqrt(a/ssq);
132 std::cout << GridLogMessage << "ConjugateGradient guess is converged already " << std::endl;
134 return;
135 }
136
137 std::cout << GridLogIterative << std::setprecision(8)
138 << "ConjugateGradient: k=0 residual " << cp << " target " << rsq << std::endl;
139
140 PreambleTimer.Stop();
141 GridStopWatch LinalgTimer;
142 GridStopWatch InnerTimer;
143 GridStopWatch AxpyNormTimer;
144 GridStopWatch LinearCombTimer;
145 GridStopWatch MatrixTimer;
146 GridStopWatch SolverTimer;
147
148 RealD usecs = -usecond();
149 SolverTimer.Start();
150 int k;
151 for (k = 1; k <= MaxIterations; k++) {
152
153 GridStopWatch IterationTimer;
154 IterationTimer.Start();
155 c = cp;
156
157 MatrixTimer.Start();
158 Linop.HermOp(p, mmp);
159 MatrixTimer.Stop();
160
161 LinalgTimer.Start();
162
163 InnerTimer.Start();
164 ComplexD dc = innerProduct(p,mmp);
165 InnerTimer.Stop();
166 d = dc.real();
167 a = c / d;
168
169 AxpyNormTimer.Start();
170 cp = axpy_norm(r, -a, mmp, r);
171 AxpyNormTimer.Stop();
172 b = cp / c;
173
174 LinearCombTimer.Start();
175 {
176 autoView( psi_v , psi, AcceleratorWrite);
177 autoView( p_v , p, AcceleratorWrite);
178 autoView( r_v , r, AcceleratorWrite);
179 accelerator_for(ss,p_v.size(), Field::vector_object::Nsimd(),{
180 coalescedWrite(psi_v[ss], a * p_v(ss) + psi_v(ss));
181 coalescedWrite(p_v[ss] , b * p_v(ss) + r_v (ss));
182 });
183 }
184 LinearCombTimer.Stop();
185 LinalgTimer.Stop();
186 LogIteration(k,a,b);
187
188 IterationTimer.Stop();
189 if ( (k % 500) == 0 ) {
190 std::cout << GridLogMessage << "ConjugateGradient: Iteration " << k
191 << " residual " << sqrt(cp/ssq) << " target " << Tolerance << std::endl;
192 } else {
193 std::cout << GridLogIterative << "ConjugateGradient: Iteration " << k
194 << " residual " << sqrt(cp/ssq) << " target " << Tolerance << " took " << IterationTimer.Elapsed() << std::endl;
195 }
196
197 // Stopping condition
198 if (cp <= rsq) {
199 usecs +=usecond();
200 SolverTimer.Stop();
201 Linop.HermOpAndNorm(psi, mmp, d, qq);
202 p = mmp - src;
203 GridBase *grid = src.Grid();
204 RealD DwfFlops = (1452. )*grid->gSites()*4*k
205 + (8+4+8+4+4)*12*grid->gSites()*k; // CG linear algebra
206 RealD srcnorm = std::sqrt(norm2(src));
207 RealD resnorm = std::sqrt(norm2(p));
208 RealD true_residual = resnorm / srcnorm;
209 std::cout << GridLogMessage << "ConjugateGradient Converged on iteration " << k
210 << "\tComputed residual " << std::sqrt(cp / ssq)
211 << "\tTrue residual " << true_residual
212 << "\tTarget " << Tolerance << std::endl;
213
214 // std::cout << GridLogMessage << "\tPreamble " << PreambleTimer.Elapsed() <<std::endl;
215 std::cout << GridLogMessage << "\tSolver Elapsed " << SolverTimer.Elapsed() <<std::endl;
216 std::cout << GridLogPerformance << "Time breakdown "<<std::endl;
217 std::cout << GridLogPerformance << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
218 std::cout << GridLogPerformance << "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
219 std::cout << GridLogPerformance << "\t\tInner " << InnerTimer.Elapsed() <<std::endl;
220 std::cout << GridLogPerformance << "\t\tAxpyNorm " << AxpyNormTimer.Elapsed() <<std::endl;
221 std::cout << GridLogPerformance << "\t\tLinearComb " << LinearCombTimer.Elapsed() <<std::endl;
222
223 std::cout << GridLogDebug << "\tMobius flop rate " << DwfFlops/ usecs<< " Gflops " <<std::endl;
224
225 if (ErrorOnNoConverge) assert(true_residual / Tolerance < 10000.0);
226
228 TrueResidual = true_residual;
229
230 return;
231 }
232 }
233 // Failed. Calculate true residual before giving up
234 // Linop.HermOpAndNorm(psi, mmp, d, qq);
235 // p = mmp - src;
236 //TrueResidual = sqrt(norm2(p)/ssq);
237 // TrueResidual = 1;
238
239 std::cout << GridLogMessage << "ConjugateGradient did NOT converge "<<k<<" / "<< MaxIterations
240 <<" residual "<< std::sqrt(cp / ssq)<< std::endl;
241 SolverTimer.Stop();
242 std::cout << GridLogMessage << "\tPreamble " << PreambleTimer.Elapsed() <<std::endl;
243 std::cout << GridLogMessage << "\tConstruct " << ConstructTimer.Elapsed() <<std::endl;
244 std::cout << GridLogMessage << "\tNorm " << NormTimer.Elapsed() <<std::endl;
245 std::cout << GridLogMessage << "\tAssign " << AssignTimer.Elapsed() <<std::endl;
246 std::cout << GridLogMessage << "\tSolver " << SolverTimer.Elapsed() <<std::endl;
247 std::cout << GridLogMessage << "Solver breakdown "<<std::endl;
248 std::cout << GridLogMessage << "\tMatrix " << MatrixTimer.Elapsed() <<std::endl;
249 std::cout << GridLogMessage<< "\tLinalg " << LinalgTimer.Elapsed() <<std::endl;
250 std::cout << GridLogPerformance << "\t\tInner " << InnerTimer.Elapsed() <<std::endl;
251 std::cout << GridLogPerformance << "\t\tAxpyNorm " << AxpyNormTimer.Elapsed() <<std::endl;
252 std::cout << GridLogPerformance << "\t\tLinearComb " << LinearCombTimer.Elapsed() <<std::endl;
253
254 if (ErrorOnNoConverge) assert(0);
256
257 }
258};
259
260
261template <class Field>
263public:
264 // Optionally record the CG polynomial
265 std::vector<double> ak;
266 std::vector<double> bk;
267 std::vector<double> poly_p;
268 std::vector<double> poly_r;
269 std::vector<double> poly_Ap;
270 std::vector<double> polynomial;
271
272public:
273 ConjugateGradientPolynomial(RealD tol, Integer maxit, bool err_on_no_conv = true)
274 : ConjugateGradient<Field>(tol,maxit,err_on_no_conv)
275 { };
276 void PolyHermOp(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi)
277 {
278 Field tmp(src.Grid());
279 Field AtoN(src.Grid());
280 AtoN = src;
281 psi=AtoN*polynomial[0];
282 for(int n=1;n<polynomial.size();n++){
283 tmp = AtoN;
284 Linop.HermOp(tmp,AtoN);
285 psi = psi + polynomial[n]*AtoN;
286 }
287 }
288 void CGsequenceHermOp(LinearOperatorBase<Field> &Linop, const Field &src, Field &x)
289 {
290 Field Ap(src.Grid());
291 Field r(src.Grid());
292 Field p(src.Grid());
293 p=src;
294 r=src;
295 x=Zero();
296 x.Checkerboard()=src.Checkerboard();
297 for(int k=0;k<ak.size();k++){
298 x = x + ak[k]*p;
299 Linop.HermOp(p,Ap);
300 r = r - ak[k] * Ap;
301 p = r + bk[k] * p;
302 }
303 }
304 void Solve(LinearOperatorBase<Field> &Linop, const Field &src, Field &psi)
305 {
306 psi=Zero();
307 this->operator ()(Linop,src,psi);
308 }
309 virtual void LogBegin(void)
310 {
311 std::cout << "ConjugageGradientPolynomial::LogBegin() "<<std::endl;
312 ak.resize(0);
313 bk.resize(0);
314 polynomial.resize(0);
315 poly_Ap.resize(0);
316 poly_Ap.resize(0);
317 poly_p.resize(1);
318 poly_r.resize(1);
319 poly_p[0]=1.0;
320 poly_r[0]=1.0;
321 };
322 virtual void LogIteration(int k,RealD a,RealD b)
323 {
324 // With zero guess,
325 // p = r = src
326 //
327 // iterate:
328 // x = x + a p
329 // r = r - a A p
330 // p = r + b p
331 //
332 // [0]
333 // r = x
334 // p = x
335 // Ap=0
336 //
337 // [1]
338 // Ap = A x + 0 ==> shift poly P right by 1 and add 0.
339 // x = x + a p ==> add polynomials term by term
340 // r = r - a A p ==> add polynomials term by term
341 // p = r + b p ==> add polynomials term by term
342 //
343 std::cout << "ConjugageGradientPolynomial::LogIteration() "<<k<<std::endl;
344 ak.push_back(a);
345 bk.push_back(b);
346 // Ap= right_shift(p)
347 poly_Ap.resize(k+1);
348 poly_Ap[0]=0.0;
349 for(int i=0;i<k;i++){
350 poly_Ap[i+1]=poly_p[i];
351 }
352
353 // x = x + a p
354 polynomial.resize(k);
355 polynomial[k-1]=0.0;
356 for(int i=0;i<k;i++){
357 polynomial[i] = polynomial[i] + a * poly_p[i];
358 }
359
360 // r = r - a Ap
361 // p = r + b p
362 poly_r.resize(k+1);
363 poly_p.resize(k+1);
364 poly_r[k] = poly_p[k] = 0.0;
365 for(int i=0;i<k+1;i++){
366 poly_r[i] = poly_r[i] - a * poly_Ap[i];
367 poly_p[i] = poly_r[i] + b * poly_p[i];
368 }
369 }
370};
371
373#endif
#define accelerator_for(iterator, num, nsimd,...)
accelerator_inline Grid_simd< S, V > sqrt(const Grid_simd< S, V > &r)
RealD axpy_norm(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void conformable(const Lattice< obj1 > &lhs, const Lattice< obj2 > &rhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
#define autoView(l_v, l, mode)
GridLogger GridLogIterative(1, "Iterative", GridLogColours, "BLUE")
GridLogger GridLogPerformance(1, "Performance", GridLogColours, "GREEN")
GridLogger GridLogDebug(1, "Debug", GridLogColours, "PURPLE")
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
@ AcceleratorWrite
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
double usecond(void)
Definition Timer.h:50
#define GRID_TRACE(name)
Definition Tracing.h:68
std::vector< double > poly_r
void Solve(LinearOperatorBase< Field > &Linop, const Field &src, Field &psi)
ConjugateGradientPolynomial(RealD tol, Integer maxit, bool err_on_no_conv=true)
std::vector< double > poly_p
void CGsequenceHermOp(LinearOperatorBase< Field > &Linop, const Field &src, Field &x)
void PolyHermOp(LinearOperatorBase< Field > &Linop, const Field &src, Field &psi)
virtual void LogIteration(int k, RealD a, RealD b)
std::vector< double > poly_Ap
std::vector< double > polynomial
virtual void LogBegin(void)
void operator()(LinearOperatorBase< Field > &Linop, const Field &src, Field &psi)
ConjugateGradient(RealD tol, Integer maxit, bool err_on_no_conv=true)
virtual void LogIteration(int k, RealD a, RealD b)
int64_t gSites(void) const
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
virtual void HermOp(const Field &in, Field &out)=0
virtual void HermOpAndNorm(const Field &in, Field &out, RealD &n1, RealD &n2)=0
Definition Simd.h:194