Grid 0.7.0
ConjugateGradientMixedPrecBatched.h
Go to the documentation of this file.
1/*************************************************************************************
2
3 Grid physics library, www.github.com/paboyle/Grid
4
5 Source file: ./lib/algorithms/iterative/ConjugateGradientMixedPrecBatched.h
6
7 Copyright (C) 2015
8
9 Author: Raoul Hodgson <raoul.hodgson@ed.ac.uk>
10
11 This program is free software; you can redistribute it and/or modify
12 it under the terms of the GNU General Public License as published by
13 the Free Software Foundation; either version 2 of the License, or
14 (at your option) any later version.
15
16 This program is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU General Public License for more details.
20
21 You should have received a copy of the GNU General Public License along
22 with this program; if not, write to the Free Software Foundation, Inc.,
23 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25 See the full license in the file "LICENSE" in the top level distribution directory
26*************************************************************************************/
27/* END LEGAL */
28#ifndef GRID_CONJUGATE_GRADIENT_MIXED_PREC_BATCHED_H
29#define GRID_CONJUGATE_GRADIENT_MIXED_PREC_BATCHED_H
30
32
33//Mixed precision restarted defect correction CG
34template<class FieldD,class FieldF,
35 typename std::enable_if< getPrecision<FieldD>::value == 2, int>::type = 0,
36 typename std::enable_if< getPrecision<FieldF>::value == 1, int>::type = 0>
38public:
39 using LinearFunction<FieldD>::operator();
41 RealD InnerTolerance; //Initial tolerance for inner CG. Defaults to Tolerance but can be changed
45 GridBase* SinglePrecGrid; //Grid for single-precision fields
46 RealD OuterLoopNormMult; //Stop the outer loop and move to a final double prec solve when the residual is OuterLoopNormMult * Tolerance
49
50 //Option to speed up *inner single precision* solves using a LinearFunction that produces a guess
53
55 Integer maxinnerit,
56 Integer maxouterit,
57 Integer maxpatchit,
58 GridBase* _sp_grid,
61 bool _updateResidual=true) :
62 Linop_f(_Linop_f), Linop_d(_Linop_d),
63 Tolerance(tol), InnerTolerance(tol), MaxInnerIterations(maxinnerit), MaxOuterIterations(maxouterit), MaxPatchupIterations(maxpatchit), SinglePrecGrid(_sp_grid),
64 OuterLoopNormMult(100.), guesser(NULL), updateResidual(_updateResidual) { };
65
67 guesser = &g;
68 }
69
70 void operator() (const FieldD &src_d_in, FieldD &sol_d){
71 std::vector<FieldD> srcs_d_in{src_d_in};
72 std::vector<FieldD> sols_d{sol_d};
73
74 (*this)(srcs_d_in,sols_d);
75
76 sol_d = sols_d[0];
77 }
78
79 void operator() (const std::vector<FieldD> &src_d_in, std::vector<FieldD> &sol_d){
80 assert(src_d_in.size() == sol_d.size());
81 int NBatch = src_d_in.size();
82
83 std::cout << GridLogMessage << "NBatch = " << NBatch << std::endl;
84
85 Integer TotalOuterIterations = 0; //Number of restarts
86 std::vector<Integer> TotalInnerIterations(NBatch,0); //Number of inner CG iterations
87 std::vector<Integer> TotalFinalStepIterations(NBatch,0); //Number of CG iterations in final patch-up step
88
89 GridStopWatch TotalTimer;
90 TotalTimer.Start();
91
92 GridStopWatch InnerCGtimer;
93 GridStopWatch PrecChangeTimer;
94
95 int cb = src_d_in[0].Checkerboard();
96
97 std::vector<RealD> src_norm;
98 std::vector<RealD> norm;
99 std::vector<RealD> stop;
100
101 GridBase* DoublePrecGrid = src_d_in[0].Grid();
102 FieldD tmp_d(DoublePrecGrid);
103 tmp_d.Checkerboard() = cb;
104
105 FieldD tmp2_d(DoublePrecGrid);
106 tmp2_d.Checkerboard() = cb;
107
108 std::vector<FieldD> src_d;
109 std::vector<FieldF> src_f;
110 std::vector<FieldF> sol_f;
111
112 for (int i=0; i<NBatch; i++) {
113 sol_d[i].Checkerboard() = cb;
114
115 src_norm.push_back(norm2(src_d_in[i]));
116 norm.push_back(0.);
117 stop.push_back(src_norm[i] * Tolerance*Tolerance);
118
119 src_d.push_back(src_d_in[i]); //source for next inner iteration, computed from residual during operation
120
121 src_f.push_back(SinglePrecGrid);
122 src_f[i].Checkerboard() = cb;
123
124 sol_f.push_back(SinglePrecGrid);
125 sol_f[i].Checkerboard() = cb;
126 }
127
128 RealD inner_tol = InnerTolerance;
129
131 CG_f.ErrorOnNoConverge = false;
132
133 Integer &outer_iter = TotalOuterIterations; //so it will be equal to the final iteration count
134
135 for(outer_iter = 0; outer_iter < MaxOuterIterations; outer_iter++){
136 std::cout << GridLogMessage << std::endl;
137 std::cout << GridLogMessage << "Outer iteration " << outer_iter << std::endl;
138
139 bool allConverged = true;
140
141 for (int i=0; i<NBatch; i++) {
142 //Compute double precision rsd and also new RHS vector.
143 Linop_d.HermOp(sol_d[i], tmp_d);
144 norm[i] = axpy_norm(src_d[i], -1., tmp_d, src_d_in[i]); //src_d is residual vector
145
146 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradientBatched: Outer iteration " << outer_iter <<" solve " << i << " residual "<< norm[i] << " target "<< stop[i] <<std::endl;
147
148 PrecChangeTimer.Start();
149 precisionChange(src_f[i], src_d[i]);
150 PrecChangeTimer.Stop();
151
152 sol_f[i] = Zero();
153
154 if(norm[i] > OuterLoopNormMult * stop[i]) {
155 allConverged = false;
156 }
157 }
158 if (allConverged) break;
159
160 if (updateResidual) {
161 RealD normMax = *std::max_element(std::begin(norm), std::end(norm));
162 RealD stopMax = *std::max_element(std::begin(stop), std::end(stop));
163 while( normMax * inner_tol * inner_tol < stopMax) inner_tol *= 2; // inner_tol = sqrt(stop/norm) ??
164 CG_f.Tolerance = inner_tol;
165 }
166
167 //Optionally improve inner solver guess (eg using known eigenvectors)
168 if(guesser != NULL) {
169 (*guesser)(src_f, sol_f);
170 }
171
172 for (int i=0; i<NBatch; i++) {
173 //Inner CG
174 InnerCGtimer.Start();
175 CG_f(Linop_f, src_f[i], sol_f[i]);
176 InnerCGtimer.Stop();
177 TotalInnerIterations[i] += CG_f.IterationsToComplete;
178
179 //Convert sol back to double and add to double prec solution
180 PrecChangeTimer.Start();
181 precisionChange(tmp_d, sol_f[i]);
182 PrecChangeTimer.Stop();
183
184 axpy(sol_d[i], 1.0, tmp_d, sol_d[i]);
185 }
186
187 }
188
189 //Final trial CG
190 std::cout << GridLogMessage << std::endl;
191 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradientBatched: Starting final patch-up double-precision solve"<<std::endl;
192
193 for (int i=0; i<NBatch; i++) {
195 CG_d(Linop_d, src_d_in[i], sol_d[i]);
196 TotalFinalStepIterations[i] += CG_d.IterationsToComplete;
197 }
198
199 TotalTimer.Stop();
200
201 std::cout << GridLogMessage << std::endl;
202 for (int i=0; i<NBatch; i++) {
203 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradientBatched: solve " << i << " Inner CG iterations " << TotalInnerIterations[i] << " Restarts " << TotalOuterIterations << " Final CG iterations " << TotalFinalStepIterations[i] << std::endl;
204 }
205 std::cout << GridLogMessage << std::endl;
206 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradientBatched: Total time " << TotalTimer.Elapsed() << " Precision change " << PrecChangeTimer.Elapsed() << " Inner CG total " << InnerCGtimer.Elapsed() << std::endl;
207
208 }
209};
210
212
213#endif
RealD axpy_norm(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
RealD norm2(const Lattice< vobj > &arg)
void precisionChange(Lattice< VobjOut > &out, const Lattice< VobjIn > &in, const precisionChangeWorkspace &workspace)
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
double RealD
Definition Simd.h:61
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
MixedPrecisionConjugateGradientBatched(RealD tol, Integer maxinnerit, Integer maxouterit, Integer maxpatchit, GridBase *_sp_grid, LinearOperatorBase< FieldF > &_Linop_f, LinearOperatorBase< FieldD > &_Linop_d, bool _updateResidual=true)
void operator()(const FieldD &src_d_in, FieldD &sol_d)
Definition Simd.h:194