Grid 0.7.0
GeneralisedMinimalResidual.h
Go to the documentation of this file.
1/*************************************************************************************
2
3Grid physics library, www.github.com/paboyle/Grid
4
5Source file: ./lib/algorithms/iterative/GeneralisedMinimalResidual.h
6
7Copyright (C) 2015
8
9Author: Daniel Richtmann <daniel.richtmann@ur.de>
10
11This program is free software; you can redistribute it and/or modify
12it under the terms of the GNU General Public License as published by
13the Free Software Foundation; either version 2 of the License, or
14(at your option) any later version.
15
16This program is distributed in the hope that it will be useful,
17but WITHOUT ANY WARRANTY; without even the implied warranty of
18MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19GNU General Public License for more details.
20
21You should have received a copy of the GNU General Public License along
22with this program; if not, write to the Free Software Foundation, Inc.,
2351 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25See the full license in the file "LICENSE" in the top level distribution
26directory
27*************************************************************************************/
28/* END LEGAL */
29#ifndef GRID_GENERALISED_MINIMAL_RESIDUAL_H
30#define GRID_GENERALISED_MINIMAL_RESIDUAL_H
31
32namespace Grid {
33
34template<class Field>
36 public:
37 using OperatorFunction<Field>::operator();
38
39 bool ErrorOnNoConverge; // Throw an assert when GMRES fails to converge,
40 // defaults to true
41
43
47 Integer IterationCount; // Number of iterations the GMRES took to finish,
48 // filled in upon completion
49
54
55 Eigen::MatrixXcd H;
56
57 std::vector<ComplexD> y;
58 std::vector<ComplexD> gamma;
59 std::vector<ComplexD> c;
60 std::vector<ComplexD> s;
61
63 Integer maxit,
64 Integer restart_length,
65 bool err_on_no_conv = true)
66 : Tolerance(tol)
67 , MaxIterations(maxit)
68 , RestartLength(restart_length)
70 , ErrorOnNoConverge(err_on_no_conv)
71 , H(Eigen::MatrixXcd::Zero(RestartLength, RestartLength + 1)) // sizes taken from DD-αAMG code base
72 , y(RestartLength + 1, 0.)
73 , gamma(RestartLength + 1, 0.)
74 , c(RestartLength + 1, 0.)
75 , s(RestartLength + 1, 0.) {};
76
77 void operator()(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi) {
78
79 psi.Checkerboard() = src.Checkerboard();
80 conformable(psi, src);
81
82 RealD guess = norm2(psi);
83 assert(std::isnan(guess) == 0);
84
85 RealD cp;
86 RealD ssq = norm2(src);
87 RealD rsq = Tolerance * Tolerance * ssq;
88
89 Field r(src.Grid());
90
91 std::cout << std::setprecision(4) << std::scientific;
92 std::cout << GridLogIterative << "GeneralisedMinimalResidual: guess " << guess << std::endl;
93 std::cout << GridLogIterative << "GeneralisedMinimalResidual: src " << ssq << std::endl;
94
95 MatrixTimer.Reset();
96 LinalgTimer.Reset();
97 QrTimer.Reset();
98 CompSolutionTimer.Reset();
99
100 GridStopWatch SolverTimer;
101 SolverTimer.Start();
102
103 IterationCount = 0;
104
105 for (int k=0; k<MaxNumberOfRestarts; k++) {
106
107 cp = outerLoopBody(LinOp, src, psi, rsq);
108
109 // Stopping condition
110 if (cp <= rsq) {
111
112 SolverTimer.Stop();
113
114 LinOp.Op(psi,r);
115 axpy(r,-1.0,src,r);
116
117 RealD srcnorm = sqrt(ssq);
118 RealD resnorm = sqrt(norm2(r));
119 RealD true_residual = resnorm / srcnorm;
120
121 std::cout << GridLogMessage << "GeneralisedMinimalResidual: Converged on iteration " << IterationCount
122 << " computed residual " << sqrt(cp / ssq)
123 << " true residual " << true_residual
124 << " target " << Tolerance << std::endl;
125
126 std::cout << GridLogMessage << "GMRES Time elapsed: Total " << SolverTimer.Elapsed() << std::endl;
127 std::cout << GridLogMessage << "GMRES Time elapsed: Matrix " << MatrixTimer.Elapsed() << std::endl;
128 std::cout << GridLogMessage << "GMRES Time elapsed: Linalg " << LinalgTimer.Elapsed() << std::endl;
129 std::cout << GridLogMessage << "GMRES Time elapsed: QR " << QrTimer.Elapsed() << std::endl;
130 std::cout << GridLogMessage << "GMRES Time elapsed: CompSol " << CompSolutionTimer.Elapsed() << std::endl;
131 return;
132 }
133 }
134
135 std::cout << GridLogMessage << "GeneralisedMinimalResidual did NOT converge" << std::endl;
136
138 assert(0);
139 }
140
141 RealD outerLoopBody(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi, RealD rsq) {
142
143 RealD cp = 0;
144
145 Field w(src.Grid());
146 Field r(src.Grid());
147
148 // this should probably be made a class member so that it is only allocated once, not in every restart
149 std::vector<Field> v(RestartLength + 1, src.Grid()); for (auto &elem : v) elem = Zero();
150
151 MatrixTimer.Start();
152 LinOp.Op(psi, w);
153 MatrixTimer.Stop();
154
155 LinalgTimer.Start();
156 r = src - w;
157
158 gamma[0] = sqrt(norm2(r));
159
160 v[0] = (1. / gamma[0]) * r;
161 LinalgTimer.Stop();
162
163 for (int i=0; i<RestartLength; i++) {
164
166
167 arnoldiStep(LinOp, v, w, i);
168
169 qrUpdate(i);
170
171 cp = norm(gamma[i+1]);
172
173 std::cout << GridLogIterative << "GeneralisedMinimalResidual: Iteration " << IterationCount
174 << " residual " << cp << " target " << rsq << std::endl;
175
176 if ((i == RestartLength - 1) || (IterationCount == MaxIterations) || (cp <= rsq)) {
177
178 computeSolution(v, psi, i);
179
180 return cp;
181 }
182 }
183
184 assert(0); // Never reached
185 return cp;
186 }
187
188 void arnoldiStep(LinearOperatorBase<Field> &LinOp, std::vector<Field> &v, Field &w, int iter) {
189
190 MatrixTimer.Start();
191 LinOp.Op(v[iter], w);
192 MatrixTimer.Stop();
193
194 LinalgTimer.Start();
195 for (int i = 0; i <= iter; ++i) {
196 H(iter, i) = innerProduct(v[i], w);
197 w = w - ComplexD(H(iter, i)) * v[i];
198 }
199
200 H(iter, iter + 1) = sqrt(norm2(w));
201 v[iter + 1] = ComplexD(1. / H(iter, iter + 1)) * w;
202 LinalgTimer.Stop();
203 }
204
205 void qrUpdate(int iter) {
206
207 QrTimer.Start();
208 for (int i = 0; i < iter ; ++i) {
209 auto tmp = -s[i] * ComplexD(H(iter, i)) + c[i] * ComplexD(H(iter, i + 1));
210 H(iter, i) = conjugate(c[i]) * ComplexD(H(iter, i)) + conjugate(s[i]) * ComplexD(H(iter, i + 1));
211 H(iter, i + 1) = tmp;
212 }
213
214 // Compute new Givens Rotation
215 auto nu = sqrt(std::norm(H(iter, iter)) + std::norm(H(iter, iter + 1)));
216 c[iter] = H(iter, iter) / nu;
217 s[iter] = H(iter, iter + 1) / nu;
218
219 // Apply new Givens rotation
220 H(iter, iter) = nu;
221 H(iter, iter + 1) = 0.;
222
223 gamma[iter + 1] = -s[iter] * gamma[iter];
224 gamma[iter] = conjugate(c[iter]) * gamma[iter];
225 QrTimer.Stop();
226 }
227
228 void computeSolution(std::vector<Field> const &v, Field &psi, int iter) {
229
230 CompSolutionTimer.Start();
231 for (int i = iter; i >= 0; i--) {
232 y[i] = gamma[i];
233 for (int k = i + 1; k <= iter; k++)
234 y[i] = y[i] - ComplexD(H(k, i)) * y[k];
235 y[i] = y[i] / ComplexD(H(i, i));
236 }
237
238 for (int i = 0; i <= iter; i++)
239 psi = psi + v[i] * y[i];
240 CompSolutionTimer.Stop();
241 }
242};
243}
244#endif
accelerator_inline Grid_simd< S, V > sqrt(const Grid_simd< S, V > &r)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void conformable(const Lattice< obj1 > &lhs, const Lattice< obj2 > &rhs)
Lattice< vobj > conjugate(const Lattice< vobj > &lhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
GridLogger GridLogIterative(1, "Iterative", GridLogColours, "BLUE")
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
uint32_t Integer
Definition Simd.h:58
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
GeneralisedMinimalResidual(RealD tol, Integer maxit, Integer restart_length, bool err_on_no_conv=true)
void arnoldiStep(LinearOperatorBase< Field > &LinOp, std::vector< Field > &v, Field &w, int iter)
void operator()(LinearOperatorBase< Field > &LinOp, const Field &src, Field &psi)
RealD outerLoopBody(LinearOperatorBase< Field > &LinOp, const Field &src, Field &psi, RealD rsq)
void computeSolution(std::vector< Field > const &v, Field &psi, int iter)
virtual void Op(const Field &in, Field &out)=0
Definition Simd.h:194