Grid 0.7.0
FlexibleCommunicationAvoidingGeneralisedMinimalResidual.h
Go to the documentation of this file.
1/*************************************************************************************
2
3Grid physics library, www.github.com/paboyle/Grid
4
5Source file: ./lib/algorithms/iterative/FlexibleCommunicationAvoidingGeneralisedMinimalResidual.h
6
7Copyright (C) 2015
8
9Author: Daniel Richtmann <daniel.richtmann@ur.de>
10
11This program is free software; you can redistribute it and/or modify
12it under the terms of the GNU General Public License as published by
13the Free Software Foundation; either version 2 of the License, or
14(at your option) any later version.
15
16This program is distributed in the hope that it will be useful,
17but WITHOUT ANY WARRANTY; without even the implied warranty of
18MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19GNU General Public License for more details.
20
21You should have received a copy of the GNU General Public License along
22with this program; if not, write to the Free Software Foundation, Inc.,
2351 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25See the full license in the file "LICENSE" in the top level distribution
26directory
27*************************************************************************************/
28/* END LEGAL */
29#ifndef GRID_FLEXIBLE_COMMUNICATION_AVOIDING_GENERALISED_MINIMAL_RESIDUAL_H
30#define GRID_FLEXIBLE_COMMUNICATION_AVOIDING_GENERALISED_MINIMAL_RESIDUAL_H
31
32namespace Grid {
33
34template<class Field>
36 public:
37 using OperatorFunction<Field>::operator();
38
39 bool ErrorOnNoConverge; // Throw an assert when FCAGMRES fails to converge,
40 // defaults to true
41
43
47 Integer IterationCount; // Number of iterations the FCAGMRES took to finish,
48 // filled in upon completion
49
55
// Hessenberg matrix of the current restart cycle, stored TRANSPOSED:
// H(j, i) is row i / column j of the textbook matrix. arnoldiStep writes
// H(iter, i) for i <= iter + 1; qrUpdate overwrites the same column in place
// with the triangular factor R, which computeSolution back-substitutes.
// Sized RestartLength x (RestartLength + 1) in the constructor below.
56 Eigen::MatrixXcd H;
57
// Coefficients of the solution update, filled by back-substitution in
// computeSolution (psi += sum_i z[i] * y[i]).
58 std::vector<ComplexD> y;
// Right-hand side of the small least-squares problem, rotated by qrUpdate;
// outerLoopBody uses norm(gamma[i+1]) as the residual estimate after step i.
59 std::vector<ComplexD> gamma;
// Givens rotation coefficients: (c[i], s[i]) is the rotation computed at
// step i by qrUpdate and re-applied to every later column.
60 std::vector<ComplexD> c;
61 std::vector<ComplexD> s;
62
64
66 Integer maxit,
68 Integer restart_length,
69 bool err_on_no_conv = true)
// NOTE(review): members are initialised in declaration order, not in the
// order written here. ErrorOnNoConverge is declared at original line 39 but
// listed after Tolerance/MaxIterations/RestartLength; this may trigger
// -Wreorder depending on where those are declared (lines not visible here).
// H, y, gamma, c and s are sized from RestartLength, so RestartLength must be
// declared before them for this to be well-defined.
70 : Tolerance(tol)
71 , MaxIterations(maxit)
72 , RestartLength(restart_length)
74 , ErrorOnNoConverge(err_on_no_conv)
75 , H(Eigen::MatrixXcd::Zero(RestartLength, RestartLength + 1)) // sizes taken from DD-αAMG code base
76 , y(RestartLength + 1, 0.)
77 , gamma(RestartLength + 1, 0.)
78 , c(RestartLength + 1, 0.)
79 , s(RestartLength + 1, 0.)
80 , Preconditioner(Prec) {};
81
// Solves LinOp * psi = src with restarted flexible GMRES (the warning logged
// below states this currently does not differ from regular FGMRES).
//
// psi : on entry the initial guess (its norm2 must not be NaN, see assert);
//       on exit the approximate solution. Its checkerboard is set to src's.
// Convergence test: residual estimate cp <= Tolerance^2 * norm2(src) (rsq).
// Runs at most MaxNumberOfRestarts cycles of outerLoopBody. On convergence
// it recomputes the true relative residual |LinOp*psi - src| / |src| for
// logging, prints the accumulated timers and returns.
//
// NOTE(review): std::setprecision/std::scientific below permanently change
// std::cout's formatting state for the rest of the program.
// NOTE(review): if src == 0 then ssq == 0, so sqrt(cp / ssq) and
// resnorm / srcnorm divide by zero. outerLoopBody also divides by the zero
// initial residual norm. An early exit for a zero source may be wanted.
82 void operator()(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi) {
83
84 std::cout << GridLogWarning << "This algorithm currently doesn't differ from regular FGMRES" << std::endl;
85
86 psi.Checkerboard() = src.Checkerboard();
87 conformable(psi, src);
88
89 RealD guess = norm2(psi);
90 assert(std::isnan(guess) == 0);
91
92 RealD cp;
93 RealD ssq = norm2(src);
// Target for the squared residual: ||r||^2 <= Tolerance^2 * ||src||^2.
94 RealD rsq = Tolerance * Tolerance * ssq;
95
96 Field r(src.Grid());
97
98 std::cout << std::setprecision(4) << std::scientific;
99 std::cout << GridLogIterative << "FlexibleCommunicationAvoidingGeneralisedMinimalResidual: guess " << guess << std::endl;
100 std::cout << GridLogIterative << "FlexibleCommunicationAvoidingGeneralisedMinimalResidual: src " << ssq << std::endl;
101
102 PrecTimer.Reset();
103 MatrixTimer.Reset();
104 LinalgTimer.Reset();
105 QrTimer.Reset();
106 CompSolutionTimer.Reset();
107
108 GridStopWatch SolverTimer;
109 SolverTimer.Start();
110
111 IterationCount = 0;
112
// One restart cycle per pass; each cycle updates psi in place.
113 for (int k=0; k<MaxNumberOfRestarts; k++) {
114
115 cp = outerLoopBody(LinOp, src, psi, rsq);
116
117 // Stopping condition
118 if (cp <= rsq) {
119
120 SolverTimer.Stop();
121
// r = LinOp*psi - src (sign irrelevant, only its norm is used).
122 LinOp.Op(psi,r);
123 axpy(r,-1.0,src,r);
124
125 RealD srcnorm = sqrt(ssq);
126 RealD resnorm = sqrt(norm2(r));
127 RealD true_residual = resnorm / srcnorm;
128
129 std::cout << GridLogMessage << "FlexibleCommunicationAvoidingGeneralisedMinimalResidual: Converged on iteration " << IterationCount
130 << " computed residual " << sqrt(cp / ssq)
131 << " true residual " << true_residual
132 << " target " << Tolerance << std::endl;
133
134 std::cout << GridLogMessage << "FCAGMRES Time elapsed: Total " << SolverTimer.Elapsed() << std::endl;
135 std::cout << GridLogMessage << "FCAGMRES Time elapsed: Precon " << PrecTimer.Elapsed() << std::endl;
136 std::cout << GridLogMessage << "FCAGMRES Time elapsed: Matrix " << MatrixTimer.Elapsed() << std::endl;
137 std::cout << GridLogMessage << "FCAGMRES Time elapsed: Linalg " << LinalgTimer.Elapsed() << std::endl;
138 std::cout << GridLogMessage << "FCAGMRES Time elapsed: QR " << QrTimer.Elapsed() << std::endl;
139 std::cout << GridLogMessage << "FCAGMRES Time elapsed: CompSol " << CompSolutionTimer.Elapsed() << std::endl;
140 return;
141 }
142 }
143
144 std::cout << GridLogMessage << "FlexibleCommunicationAvoidingGeneralisedMinimalResidual did NOT converge" << std::endl;
145
// NOTE(review): original line 146 is missing from this listing; it presumably
// makes the assert below conditional on ErrorOnNoConverge — confirm.
147 assert(0);
148 }
149
// Runs one restart cycle of flexible GMRES, starting from the current psi.
//
// Computes r = src - LinOp*psi, normalises it into v[0], then performs up to
// RestartLength Arnoldi + QR steps. The cycle stops at the last step, when
// IterationCount == MaxIterations, or when cp <= rsq. In every exit case psi
// is updated through computeSolution from the preconditioned directions
// z[0..i].
// Returns cp = norm(gamma[i+1]) (std::norm on a complex value, i.e. the
// squared magnitude), the residual estimate that the caller compares to rsq.
//
// NOTE(review): if the initial residual is exactly zero (e.g. src == 0 and
// psi == 0), gamma[0] == 0 and 1. / gamma[0] is non-finite. v[0], and via
// computeSolution psi itself, then become inf/NaN. Returning 0 early when
// gamma[0] == 0 would avoid this.
// NOTE(review): with RestartLength == 0 the loop never runs and the
// trailing assert(0) fires.
150 RealD outerLoopBody(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi, RealD rsq) {
151
152 RealD cp = 0;
153
154 Field w(src.Grid());
155 Field r(src.Grid());
156
157 // these should probably be made class members so that they are only allocated once, not in every restart
158 std::vector<Field> v(RestartLength + 1, src.Grid()); for (auto &elem : v) elem = Zero();
159 std::vector<Field> z(RestartLength + 1, src.Grid()); for (auto &elem : z) elem = Zero();
160
161 MatrixTimer.Start();
162 LinOp.Op(psi, w);
163 MatrixTimer.Stop();
164
165 LinalgTimer.Start();
166 r = src - w;
167
// Initial residual norm becomes the least-squares right-hand side.
168 gamma[0] = sqrt(norm2(r));
169
170 v[0] = (1. / gamma[0]) * r;
171 LinalgTimer.Stop();
172
173 for (int i=0; i<RestartLength; i++) {
174
// NOTE(review): original line 175 is missing from this listing. The
// IterationCount == MaxIterations test below only works if IterationCount
// is incremented there — confirm.
176
177 arnoldiStep(LinOp, v, z, w, i);
178
179 qrUpdate(i);
180
181 cp = norm(gamma[i+1]);
182
183 std::cout << GridLogIterative << "FlexibleCommunicationAvoidingGeneralisedMinimalResidual: Iteration " << IterationCount
184 << " residual " << cp << " target " << rsq << std::endl;
185
186 if ((i == RestartLength - 1) || (IterationCount == MaxIterations) || (cp <= rsq)) {
187
188 computeSolution(z, psi, i);
189
190 return cp;
191 }
192 }
193
194 assert(0); // Never reached
195 return cp;
196 }
197
198 void arnoldiStep(LinearOperatorBase<Field> &LinOp, std::vector<Field> &v, std::vector<Field> &z, Field &w, int iter) {
199
200 PrecTimer.Start();
201 Preconditioner(v[iter], z[iter]);
202 PrecTimer.Stop();
203
204 MatrixTimer.Start();
205 LinOp.Op(z[iter], w);
206 MatrixTimer.Stop();
207
208 LinalgTimer.Start();
209 for (int i = 0; i <= iter; ++i) {
210 H(iter, i) = innerProduct(v[i], w);
211 w = w - ComplexD(H(iter, i)) * v[i];
212 }
213
214 H(iter, iter + 1) = sqrt(norm2(w));
215 v[iter + 1] = ComplexD(1. / H(iter, iter + 1)) * w;
216 LinalgTimer.Stop();
217 }
218
219 void qrUpdate(int iter) {
220
221 QrTimer.Start();
222 for (int i = 0; i < iter ; ++i) {
223 auto tmp = -s[i] * ComplexD(H(iter, i)) + c[i] * ComplexD(H(iter, i + 1));
224 H(iter, i) = conjugate(c[i]) * ComplexD(H(iter, i)) + conjugate(s[i]) * ComplexD(H(iter, i + 1));
225 H(iter, i + 1) = tmp;
226 }
227
228 // Compute new Givens Rotation
229 auto nu = sqrt(std::norm(H(iter, iter)) + std::norm(H(iter, iter + 1)));
230 c[iter] = H(iter, iter) / nu;
231 s[iter] = H(iter, iter + 1) / nu;
232
233 // Apply new Givens rotation
234 H(iter, iter) = nu;
235 H(iter, iter + 1) = 0.;
236
237 gamma[iter + 1] = -s[iter] * gamma[iter];
238 gamma[iter] = conjugate(c[iter]) * gamma[iter];
239 QrTimer.Stop();
240 }
241
242 void computeSolution(std::vector<Field> const &z, Field &psi, int iter) {
243
244 CompSolutionTimer.Start();
245 for (int i = iter; i >= 0; i--) {
246 y[i] = gamma[i];
247 for (int k = i + 1; k <= iter; k++)
248 y[i] = y[i] - ComplexD(H(k, i)) * y[k];
249 y[i] = y[i] / ComplexD(H(i, i));
250 }
251
252 for (int i = 0; i <= iter; i++)
253 psi = psi + z[i] * y[i];
254 CompSolutionTimer.Stop();
255 }
256};
257}
258#endif
accelerator_inline Grid_simd< S, V > sqrt(const Grid_simd< S, V > &r)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void conformable(const Lattice< obj1 > &lhs, const Lattice< obj2 > &rhs)
Lattice< vobj > conjugate(const Lattice< vobj > &lhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
GridLogger GridLogIterative(1, "Iterative", GridLogColours, "BLUE")
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
GridLogger GridLogWarning(1, "Warning", GridLogColours, "YELLOW")
uint32_t Integer
Definition Simd.h:58
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
void operator()(LinearOperatorBase< Field > &LinOp, const Field &src, Field &psi)
FlexibleCommunicationAvoidingGeneralisedMinimalResidual(RealD tol, Integer maxit, LinearFunction< Field > &Prec, Integer restart_length, bool err_on_no_conv=true)
RealD outerLoopBody(LinearOperatorBase< Field > &LinOp, const Field &src, Field &psi, RealD rsq)
void arnoldiStep(LinearOperatorBase< Field > &LinOp, std::vector< Field > &v, std::vector< Field > &z, Field &w, int iter)
virtual void Op(const Field &in, Field &out)=0
Definition Simd.h:194