Grid 0.7.0
CommunicationAvoidingGeneralisedMinimalResidual.h
Go to the documentation of this file.
1/*************************************************************************************
2
3Grid physics library, www.github.com/paboyle/Grid
4
5Source file: ./lib/algorithms/iterative/CommunicationAvoidingGeneralisedMinimalResidual.h
6
7Copyright (C) 2015
8
9Author: Daniel Richtmann <daniel.richtmann@ur.de>
10
11This program is free software; you can redistribute it and/or modify
12it under the terms of the GNU General Public License as published by
13the Free Software Foundation; either version 2 of the License, or
14(at your option) any later version.
15
16This program is distributed in the hope that it will be useful,
17but WITHOUT ANY WARRANTY; without even the implied warranty of
18MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19GNU General Public License for more details.
20
21You should have received a copy of the GNU General Public License along
22with this program; if not, write to the Free Software Foundation, Inc.,
2351 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25See the full license in the file "LICENSE" in the top level distribution
26directory
27*************************************************************************************/
28/* END LEGAL */
29#ifndef GRID_COMMUNICATION_AVOIDING_GENERALISED_MINIMAL_RESIDUAL_H
30#define GRID_COMMUNICATION_AVOIDING_GENERALISED_MINIMAL_RESIDUAL_H
31
32namespace Grid {
33
34template<class Field>
36 public:
37 using OperatorFunction<Field>::operator();
38
39 bool ErrorOnNoConverge; // Throw an assert when CAGMRES fails to converge,
40 // defaults to true
41
43
47 Integer IterationCount; // Number of iterations the CAGMRES took to finish,
48 // filled in upon completion
49
54
55 Eigen::MatrixXcd H;
56
57 std::vector<ComplexD> y;
58 std::vector<ComplexD> gamma;
59 std::vector<ComplexD> c;
60 std::vector<ComplexD> s;
61
63 Integer maxit,
64 Integer restart_length,
65 bool err_on_no_conv = true)
66 : Tolerance(tol)
67 , MaxIterations(maxit)
68 , RestartLength(restart_length)
70 , ErrorOnNoConverge(err_on_no_conv)
71 , H(Eigen::MatrixXcd::Zero(RestartLength, RestartLength + 1)) // sizes taken from DD-αAMG code base
72 , y(RestartLength + 1, 0.)
73 , gamma(RestartLength + 1, 0.)
74 , c(RestartLength + 1, 0.)
75 , s(RestartLength + 1, 0.) {};
76
77 void operator()(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi) {
78
79 std::cout << GridLogWarning << "This algorithm currently doesn't differ from regular GMRES" << std::endl;
80
81 psi.Checkerboard() = src.Checkerboard();
82 conformable(psi, src);
83
84 RealD guess = norm2(psi);
85 assert(std::isnan(guess) == 0);
86
87 RealD cp;
88 RealD ssq = norm2(src);
89 RealD rsq = Tolerance * Tolerance * ssq;
90
91 Field r(src.Grid());
92
93 std::cout << std::setprecision(4) << std::scientific;
94 std::cout << GridLogIterative << "CommunicationAvoidingGeneralisedMinimalResidual: guess " << guess << std::endl;
95 std::cout << GridLogIterative << "CommunicationAvoidingGeneralisedMinimalResidual: src " << ssq << std::endl;
96
97 MatrixTimer.Reset();
98 LinalgTimer.Reset();
99 QrTimer.Reset();
100 CompSolutionTimer.Reset();
101
102 GridStopWatch SolverTimer;
103 SolverTimer.Start();
104
105 IterationCount = 0;
106
107 for (int k=0; k<MaxNumberOfRestarts; k++) {
108
109 cp = outerLoopBody(LinOp, src, psi, rsq);
110
111 // Stopping condition
112 if (cp <= rsq) {
113
114 SolverTimer.Stop();
115
116 LinOp.Op(psi,r);
117 axpy(r,-1.0,src,r);
118
119 RealD srcnorm = sqrt(ssq);
120 RealD resnorm = sqrt(norm2(r));
121 RealD true_residual = resnorm / srcnorm;
122
123 std::cout << GridLogMessage << "CommunicationAvoidingGeneralisedMinimalResidual: Converged on iteration " << IterationCount
124 << " computed residual " << sqrt(cp / ssq)
125 << " true residual " << true_residual
126 << " target " << Tolerance << std::endl;
127
128 std::cout << GridLogMessage << "CAGMRES Time elapsed: Total " << SolverTimer.Elapsed() << std::endl;
129 std::cout << GridLogMessage << "CAGMRES Time elapsed: Matrix " << MatrixTimer.Elapsed() << std::endl;
130 std::cout << GridLogMessage << "CAGMRES Time elapsed: Linalg " << LinalgTimer.Elapsed() << std::endl;
131 std::cout << GridLogMessage << "CAGMRES Time elapsed: QR " << QrTimer.Elapsed() << std::endl;
132 std::cout << GridLogMessage << "CAGMRES Time elapsed: CompSol " << CompSolutionTimer.Elapsed() << std::endl;
133 return;
134 }
135 }
136
137 std::cout << GridLogMessage << "CommunicationAvoidingGeneralisedMinimalResidual did NOT converge" << std::endl;
138
140 assert(0);
141 }
142
143 RealD outerLoopBody(LinearOperatorBase<Field> &LinOp, const Field &src, Field &psi, RealD rsq) {
144
145 RealD cp = 0;
146
147 Field w(src.Grid());
148 Field r(src.Grid());
149
150 // this should probably be made a class member so that it is only allocated once, not in every restart
151 std::vector<Field> v(RestartLength + 1, src.Grid()); for (auto &elem : v) elem = Zero();
152
153 MatrixTimer.Start();
154 LinOp.Op(psi, w);
155 MatrixTimer.Stop();
156
157 LinalgTimer.Start();
158 r = src - w;
159
160 gamma[0] = sqrt(norm2(r));
161
162 ComplexD scale = 1.0/gamma[0];
163 v[0] = scale * r;
164
165 LinalgTimer.Stop();
166
167 for (int i=0; i<RestartLength; i++) {
168
170
171 arnoldiStep(LinOp, v, w, i);
172
173 qrUpdate(i);
174
175 cp = norm(gamma[i+1]);
176
177 std::cout << GridLogIterative << "CommunicationAvoidingGeneralisedMinimalResidual: Iteration " << IterationCount
178 << " residual " << cp << " target " << rsq << std::endl;
179
180 if ((i == RestartLength - 1) || (IterationCount == MaxIterations) || (cp <= rsq)) {
181
182 computeSolution(v, psi, i);
183
184 return cp;
185 }
186 }
187
188 assert(0); // Never reached
189 return cp;
190 }
191
192 void arnoldiStep(LinearOperatorBase<Field> &LinOp, std::vector<Field> &v, Field &w, int iter) {
193
194 MatrixTimer.Start();
195 LinOp.Op(v[iter], w);
196 MatrixTimer.Stop();
197
198 LinalgTimer.Start();
199 for (int i = 0; i <= iter; ++i) {
200 H(iter, i) = innerProduct(v[i], w);
201 w = w - ComplexD(H(iter, i)) * v[i];
202 }
203
204 H(iter, iter + 1) = sqrt(norm2(w));
205 v[iter + 1] = ComplexD(1. / H(iter, iter + 1)) * w;
206 LinalgTimer.Stop();
207 }
208
209 void qrUpdate(int iter) {
210
211 QrTimer.Start();
212 for (int i = 0; i < iter ; ++i) {
213 auto tmp = -s[i] * ComplexD(H(iter, i)) + c[i] * ComplexD(H(iter, i + 1));
214 H(iter, i) = conjugate(c[i]) * ComplexD(H(iter, i)) + conjugate(s[i]) * ComplexD(H(iter, i + 1));
215 H(iter, i + 1) = tmp;
216 }
217
218 // Compute new Givens Rotation
219 auto nu = sqrt(std::norm(H(iter, iter)) + std::norm(H(iter, iter + 1)));
220 c[iter] = H(iter, iter) / nu;
221 s[iter] = H(iter, iter + 1) / nu;
222
223 // Apply new Givens rotation
224 H(iter, iter) = nu;
225 H(iter, iter + 1) = 0.;
226
227 gamma[iter + 1] = -s[iter] * gamma[iter];
228 gamma[iter] = conjugate(c[iter]) * gamma[iter];
229 QrTimer.Stop();
230 }
231
232 void computeSolution(std::vector<Field> const &v, Field &psi, int iter) {
233
234 CompSolutionTimer.Start();
235 for (int i = iter; i >= 0; i--) {
236 y[i] = gamma[i];
237 for (int k = i + 1; k <= iter; k++)
238 y[i] = y[i] - ComplexD(H(k, i)) * y[k];
239 y[i] = y[i] / ComplexD(H(i, i));
240 }
241
242 for (int i = 0; i <= iter; i++)
243 psi = psi + v[i] * y[i];
244 CompSolutionTimer.Stop();
245 }
246};
247}
248#endif
accelerator_inline Grid_simd< S, V > sqrt(const Grid_simd< S, V > &r)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void conformable(const Lattice< obj1 > &lhs, const Lattice< obj2 > &rhs)
Lattice< vobj > conjugate(const Lattice< vobj > &lhs)
ComplexD innerProduct(const Lattice< vobj > &left, const Lattice< vobj > &right)
RealD norm2(const Lattice< vobj > &arg)
GridLogger GridLogIterative(1, "Iterative", GridLogColours, "BLUE")
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
GridLogger GridLogWarning(1, "Warning", GridLogColours, "YELLOW")
uint32_t Integer
Definition Simd.h:58
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
void arnoldiStep(LinearOperatorBase< Field > &LinOp, std::vector< Field > &v, Field &w, int iter)
RealD outerLoopBody(LinearOperatorBase< Field > &LinOp, const Field &src, Field &psi, RealD rsq)
void computeSolution(std::vector< Field > const &v, Field &psi, int iter)
void operator()(LinearOperatorBase< Field > &LinOp, const Field &src, Field &psi)
CommunicationAvoidingGeneralisedMinimalResidual(RealD tol, Integer maxit, Integer restart_length, bool err_on_no_conv=true)
virtual void Op(const Field &in, Field &out)=0
Definition Simd.h:194