Grid 0.7.0
FlexibleGeneralisedMinimalResidual.h
Go to the documentation of this file.
1/*************************************************************************************
2
3Grid physics library, www.github.com/paboyle/Grid
4
5Source file: ./lib/algorithms/iterative/FlexibleGeneralisedMinimalResidual.h
6
7Copyright (C) 2015
8
9Author: Daniel Richtmann <daniel.richtmann@ur.de>
10
11This program is free software; you can redistribute it and/or modify
12it under the terms of the GNU General Public License as published by
13the Free Software Foundation; either version 2 of the License, or
14(at your option) any later version.
15
16This program is distributed in the hope that it will be useful,
17but WITHOUT ANY WARRANTY; without even the implied warranty of
18MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19GNU General Public License for more details.
20
21You should have received a copy of the GNU General Public License along
22with this program; if not, write to the Free Software Foundation, Inc.,
2351 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25See the full license in the file "LICENSE" in the top level distribution
26directory
27*************************************************************************************/
28/* END LEGAL */
29#ifndef GRID_FLEXIBLE_GENERALISED_MINIMAL_RESIDUAL_H
30#define GRID_FLEXIBLE_GENERALISED_MINIMAL_RESIDUAL_H
31
32namespace Grid {
33
34template<class Field>
36 public:
37 using OperatorFunction<Field>::operator();
38
39 bool ErrorOnNoConverge; // Throw an assert when FGMRES fails to converge,
40 // defaults to true
41
43
47 Integer IterationCount; // Number of iterations the FGMRES took to finish,
48 // filled in upon completion
49
55
56 Eigen::MatrixXcd H;
57
58 std::vector<ComplexD> y;
59 std::vector<ComplexD> gamma;
60 std::vector<ComplexD> c;
61 std::vector<ComplexD> s;
62
64
66 Integer maxit,
68 Integer restart_length,
69 bool err_on_no_conv = true)
70 : Tolerance(tol)
71 , MaxIterations(maxit)
72 , RestartLength(restart_length)
74 , ErrorOnNoConverge(err_on_no_conv)
75 , H(Eigen::MatrixXcd::Zero(RestartLength, RestartLength + 1)) // sizes taken from DD-αAMG code base
76 , y(RestartLength + 1, 0.)
77 , gamma(RestartLength + 1, 0.)
78 , c(RestartLength + 1, 0.)
79 , s(RestartLength + 1, 0.)
80 , Preconditioner(Prec) {};
81
82 void operator()(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi) {
83
84 psi.Checkerboard() = src.Checkerboard();
85 conformable(psi, src);
86
87 RealD guess = norm2(psi);
88 assert(std::isnan(guess) == 0);
89
90 RealD cp;
91 RealD ssq = norm2(src);
92 RealD rsq = Tolerance * Tolerance * ssq;
93
94 Field r(src.Grid());
95
96 std::cout << std::setprecision(4) << std::scientific;
97 std::cout << GridLogIterative << "FlexibleGeneralisedMinimalResidual: guess " << guess << std::endl;
98 std::cout << GridLogIterative << "FlexibleGeneralisedMinimalResidual: src " << ssq << std::endl;
99
100 PrecTimer.Reset();
101 MatrixTimer.Reset();
102 LinalgTimer.Reset();
103 QrTimer.Reset();
104 CompSolutionTimer.Reset();
105
106 GridStopWatch SolverTimer;
107 SolverTimer.Start();
108
109 IterationCount = 0;
110
111 for (int k=0; k<MaxNumberOfRestarts; k++) {
112
113 cp = outerLoopBody(LinOp, src, psi, rsq);
114
115 // Stopping condition
116 if (cp <= rsq) {
117
118 SolverTimer.Stop();
119
120 LinOp.Op(psi,r);
121 axpy(r,-1.0,src,r);
122
123 RealD srcnorm = sqrt(ssq);
124 RealD resnorm = sqrt(norm2(r));
125 RealD true_residual = resnorm / srcnorm;
126
127 std::cout << GridLogMessage << "FlexibleGeneralisedMinimalResidual: Converged on iteration " << IterationCount
128 << " computed residual " << sqrt(cp / ssq)
129 << " true residual " << true_residual
130 << " target " << Tolerance << std::endl;
131
132 std::cout << GridLogMessage << "FGMRES Time elapsed: Total " << SolverTimer.Elapsed() << std::endl;
133 std::cout << GridLogMessage << "FGMRES Time elapsed: Precon " << PrecTimer.Elapsed() << std::endl;
134 std::cout << GridLogMessage << "FGMRES Time elapsed: Matrix " << MatrixTimer.Elapsed() << std::endl;
135 std::cout << GridLogMessage << "FGMRES Time elapsed: Linalg " << LinalgTimer.Elapsed() << std::endl;
136 std::cout << GridLogMessage << "FGMRES Time elapsed: QR " << QrTimer.Elapsed() << std::endl;
137 std::cout << GridLogMessage << "FGMRES Time elapsed: CompSol " << CompSolutionTimer.Elapsed() << std::endl;
138 return;
139 }
140 }
141
142 std::cout << GridLogMessage << "FlexibleGeneralisedMinimalResidual did NOT converge" << std::endl;
143
145 assert(0);
146 }
147
148 RealD outerLoopBody(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi, RealD rsq) {
149
150 RealD cp = 0;
151
152 Field w(src.Grid());
153 Field r(src.Grid());
154
155 // these should probably be made class members so that they are only allocated once, not in every restart
156 std::vector<Field> v(RestartLength + 1, src.Grid()); for (auto &elem : v) elem = Zero();
157 std::vector<Field> z(RestartLength + 1, src.Grid()); for (auto &elem : z) elem = Zero();
158
159 MatrixTimer.Start();
160 LinOp.Op(psi, w);
161 MatrixTimer.Stop();
162
163 LinalgTimer.Start();
164 r = src - w;
165
166 gamma[0] = sqrt(norm2(r));
167
168 v[0] = (1. / gamma[0]) * r;
169 LinalgTimer.Stop();
170
171 for (int i=0; i<RestartLength; i++) {
172
174
175 arnoldiStep(LinOp, v, z, w, i);
176
177 qrUpdate(i);
178
179 cp = norm(gamma[i+1]);
180
181 std::cout << GridLogIterative << "FlexibleGeneralisedMinimalResidual: Iteration " << IterationCount
182 << " residual " << cp << " target " << rsq << std::endl;
183
184 if ((i == RestartLength - 1) || (IterationCount == MaxIterations) || (cp <= rsq)) {
185
186 computeSolution(z, psi, i);
187
188 return cp;
189 }
190 }
191
192 assert(0); // Never reached
193 return cp;
194 }
195
196 void arnoldiStep(LinearOperatorBase<Field> &LinOp, std::vector<Field> &v, std::vector<Field> &z, Field &w, int iter) {
197
198 PrecTimer.Start();
199 Preconditioner(v[iter], z[iter]);
200 PrecTimer.Stop();
201
202 MatrixTimer.Start();
203 LinOp.Op(z[iter], w);
204 MatrixTimer.Stop();
205
206 LinalgTimer.Start();
207 for (int i = 0; i <= iter; ++i) {
208 H(iter, i) = innerProduct(v[i], w);
209 w = w - ComplexD(H(iter, i)) * v[i];
210 }
211
212 H(iter, iter + 1) = sqrt(norm2(w));
213 v[iter + 1] = ComplexD(1. / H(iter, iter + 1)) * w;
214 LinalgTimer.Stop();
215 }
216
217 void qrUpdate(int iter) {
218
219 QrTimer.Start();
220 for (int i = 0; i < iter ; ++i) {
221 auto tmp = -s[i] * ComplexD(H(iter, i)) + c[i] * ComplexD(H(iter, i + 1));
222 H(iter, i) = conjugate(c[i]) * ComplexD(H(iter, i)) + conjugate(s[i]) * ComplexD(H(iter, i + 1));
223 H(iter, i + 1) = tmp;
224 }
225
226 // Compute new Givens Rotation
227 auto nu = sqrt(std::norm(H(iter, iter)) + std::norm(H(iter, iter + 1)));
228 c[iter] = H(iter, iter) / nu;
229 s[iter] = H(iter, iter + 1) / nu;
230
231 // Apply new Givens rotation
232 H(iter, iter) = nu;
233 H(iter, iter + 1) = 0.;
234
235 gamma[iter + 1] = -s[iter] * gamma[iter];
236 gamma[iter] = conjugate(c[iter]) * gamma[iter];
237 QrTimer.Stop();
238 }
239
240 void computeSolution(std::vector<Field> const &z, Field &psi, int iter) {
241
242 CompSolutionTimer.Start();
243 for (int i = iter; i >= 0; i--) {
244 y[i] = gamma[i];
245 for (int k = i + 1; k <= iter; k++)
246 y[i] = y[i] - ComplexD(H(k, i)) * y[k];
247 y[i] = y[i] / ComplexD(H(i, i));
248 }
249
250 for (int i = 0; i <= iter; i++)
251 psi = psi + z[i] * y[i];
252 CompSolutionTimer.Stop();
253 }
254};
255}
256#endif
accelerator_inline Grid_simd< S, V > sqrt(const Grid_simd< S, V > &r)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void conformable(const Lattice< obj1 > &lhs, const Lattice< obj2 > &rhs)
Lattice< vobj > conjugate(const Lattice< vobj > &lhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
GridLogger GridLogIterative(1, "Iterative", GridLogColours, "BLUE")
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
uint32_t Integer
Definition Simd.h:58
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
void operator()(LinearOperatorBase< Field > &LinOp, const Field &src, Field &psi)
void computeSolution(std::vector< Field > const &z, Field &psi, int iter)
FlexibleGeneralisedMinimalResidual(RealD tol, Integer maxit, LinearFunction< Field > &Prec, Integer restart_length, bool err_on_no_conv=true)
void arnoldiStep(LinearOperatorBase< Field > &LinOp, std::vector< Field > &v, std::vector< Field > &z, Field &w, int iter)
RealD outerLoopBody(LinearOperatorBase< Field > &LinOp, const Field &src, Field &psi, RealD rsq)
virtual void Op(const Field &in, Field &out)=0
Definition Simd.h:194