Grid 0.7.0
ConjugateGradientMixedPrec.h
Go to the documentation of this file.
1/*************************************************************************************
2
3 Grid physics library, www.github.com/paboyle/Grid
4
5 Source file: ./lib/algorithms/iterative/ConjugateGradientMixedPrec.h
6
7 Copyright (C) 2015
8
9Author: Christopher Kelly <ckelly@phys.columbia.edu>
10
11 This program is free software; you can redistribute it and/or modify
12 it under the terms of the GNU General Public License as published by
13 the Free Software Foundation; either version 2 of the License, or
14 (at your option) any later version.
15
16 This program is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU General Public License for more details.
20
21 You should have received a copy of the GNU General Public License along
22 with this program; if not, write to the Free Software Foundation, Inc.,
23 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25 See the full license in the file "LICENSE" in the top level distribution directory
26*************************************************************************************/
27/* END LEGAL */
28#ifndef GRID_CONJUGATE_GRADIENT_MIXED_PREC_H
29#define GRID_CONJUGATE_GRADIENT_MIXED_PREC_H
30
32
33 //Mixed precision restarted defect correction CG
34 template<class FieldD,class FieldF,
35 typename std::enable_if< getPrecision<FieldD>::value == 2, int>::type = 0,
36 typename std::enable_if< getPrecision<FieldF>::value == 1, int>::type = 0>
38 public:
39 using LinearFunction<FieldD>::operator();
41 RealD InnerTolerance; //Initial tolerance for the inner (single-precision) CG. The constructor sets it equal to Tolerance; it can be changed afterwards. operator() may loosen its local copy between restarts.
44 GridBase* SinglePrecGrid; //Grid on which the single-precision work fields are allocated
45 RealD OuterLoopNormMult; //Stop the outer loop and move to the final double prec solve once the residual norm returned by axpy_norm falls below OuterLoopNormMult * Tolerance*Tolerance * norm2(source). The constructor sets it to 100.
48
49 Integer TotalInnerIterations; //Number of inner CG iterations
50 Integer TotalOuterIterations; //Number of restarts; operator() uses this member directly as its outer-loop counter
51 Integer TotalFinalStepIterations; //Number of CG iterations in final patch-up step
53
54 //Option to speed up *inner single precision* solves using a LinearFunction that produces a guess
56
58 Integer maxinnerit,
59 Integer maxouterit,
60 GridBase* _sp_grid,
63 Linop_f(_Linop_f), Linop_d(_Linop_d),
64 Tolerance(tol), InnerTolerance(tol), MaxInnerIterations(maxinnerit), MaxOuterIterations(maxouterit), SinglePrecGrid(_sp_grid),
65 OuterLoopNormMult(100.), guesser(NULL){ };
66
68 guesser = &g;
69 }
70
// Mixed-precision restarted solve.
//   src_d_in : double-precision source; never modified here.
//   sol_d    : double-precision solution. It is read (by the first HermOp) before it is
//              written, so its incoming contents act as the starting guess.
// Each outer iteration forms the double-precision residual, solves for a correction with the
// single-precision CG, and adds that correction to sol_d. When the residual is small enough, or
// MaxOuterIterations is reached, a final double-precision CG is run on the original source.
71 void operator() (const FieldD &src_d_in, FieldD &sol_d){
72 std::cout << GridLogMessage << "MixedPrecisionConjugateGradient: Starting mixed precision CG with outer tolerance " << Tolerance << " and inner tolerance " << InnerTolerance << std::endl;
74
75 GridStopWatch TotalTimer;
76 TotalTimer.Start();
77
// The solution and all work fields are given the checkerboard of the source.
78 int cb = src_d_in.Checkerboard();
79 sol_d.Checkerboard() = cb;
80
// Outer stopping target: Tolerance^2 times norm2 of the source
// (norm2 is presumably the squared norm, which is why Tolerance enters squared - confirm).
81 RealD src_norm = norm2(src_d_in);
82 RealD stop = src_norm * Tolerance*Tolerance;
83
84 GridBase* DoublePrecGrid = src_d_in.Grid();
85 FieldD tmp_d(DoublePrecGrid);
86 tmp_d.Checkerboard() = cb;
87
// NOTE(review): tmp2_d is not referenced again in the visible part of this function.
88 FieldD tmp2_d(DoublePrecGrid);
89 tmp2_d.Checkerboard() = cb;
90
91 FieldD src_d(DoublePrecGrid);
92 src_d = src_d_in; //source for next inner iteration; overwritten with the current residual at the top of every outer iteration
93
// Local working copy, so loosening the inner tolerance below does not modify the InnerTolerance member.
94 RealD inner_tol = InnerTolerance;
95
96 FieldF src_f(SinglePrecGrid);
97 src_f.Checkerboard() = cb;
98
99 FieldF sol_f(SinglePrecGrid);
100 sol_f.Checkerboard() = cb;
101
102 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradient: Starting initial inner CG with tolerance " << inner_tol << std::endl;
// NOTE(review): the declaration of CG_f is not visible in this listing. Presumably this flag lets
// the inner solve return without converging instead of raising an error - confirm.
104 CG_f.ErrorOnNoConverge = false;
105
106 GridStopWatch InnerCGtimer;
107
108 GridStopWatch PrecChangeTimer;
109
110 Integer &outer_iter = TotalOuterIterations; //alias of the member, so TotalOuterIterations holds the final outer iteration count when the loop exits
111
// Workspaces handed to precisionChange: the first is used for single -> double conversions,
// the second for double -> single conversions.
112 precisionChangeWorkspace pc_wk_sp_to_dp(DoublePrecGrid, SinglePrecGrid);
113 precisionChangeWorkspace pc_wk_dp_to_sp(SinglePrecGrid, DoublePrecGrid);
114
115 for(outer_iter = 0; outer_iter < MaxOuterIterations; outer_iter++){
116 //Compute the double precision residual, which is also the RHS vector for the next inner solve.
117 Linop_d.HermOp(sol_d, tmp_d);
118 RealD norm = axpy_norm(src_d, -1., tmp_d, src_d_in); //src_d = src_d_in - tmp_d is the residual vector; norm is its norm as returned by axpy_norm
119 std::cout<<GridLogMessage<<" rsd norm "<<norm<<std::endl;
120 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradient: Outer iteration " <<outer_iter<<" residual "<< norm<< " target "<< stop<<std::endl;
121
// Leave the outer loop once the residual is within a factor OuterLoopNormMult of the target;
// the remaining reduction is left to the final double-precision solve.
122 if(norm < OuterLoopNormMult * stop){
123 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradient: Outer iteration converged on iteration " <<outer_iter <<std::endl;
124 break;
125 }
126 while(norm * inner_tol * inner_tol < stop*1.01) inner_tol *= 2; // Loosen inner_tol by doubling until norm * inner_tol^2 >= 1.01 * stop, presumably so the inner solve is not asked to go below the outer target. Open question from the original author: use inner_tol = sqrt(stop/norm) instead?
127
128 PrecChangeTimer.Start();
129 precisionChange(src_f, src_d, pc_wk_dp_to_sp);
130 PrecChangeTimer.Stop();
131
// The correction solve starts from zero unless a guesser overwrites sol_f below.
132 sol_f = Zero();
133
134 //Optionally improve inner solver guess (eg using known eigenvectors)
135 if(guesser != NULL)
136 (*guesser)(src_f, sol_f);
137
138 //Inner CG: single-precision solve for the correction, with the residual as its source
139 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradient: Outer iteration " << outer_iter << " starting inner CG with tolerance " << inner_tol << std::endl;
140 CG_f.Tolerance = inner_tol;
141 InnerCGtimer.Start();
142 CG_f(Linop_f, src_f, sol_f);
143 InnerCGtimer.Stop();
145
146 //Convert the single-precision correction back to double and add it to the double prec solution
147 PrecChangeTimer.Start();
148 precisionChange(tmp_d, sol_f, pc_wk_sp_to_dp);
149 PrecChangeTimer.Stop();
150
151 axpy(sol_d, 1.0, tmp_d, sol_d);
152 }
153
154 //Final double-precision CG on the original source, with the accumulated sol_d passed in as the solution field
155 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradient: Starting final patch-up double-precision solve"<<std::endl;
156
// NOTE(review): the declaration of CG_d is not visible in this listing.
158 CG_d(Linop_d, src_d_in, sol_d);
161
162 TotalTimer.Stop();
163 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradient: Inner CG iterations " << TotalInnerIterations << " Restarts " << TotalOuterIterations << " Final CG iterations " << TotalFinalStepIterations << std::endl;
164 std::cout<<GridLogMessage<<"MixedPrecisionConjugateGradient: Total time " << TotalTimer.Elapsed() << " Precision change " << PrecChangeTimer.Elapsed() << " Inner CG total " << InnerCGtimer.Elapsed() << std::endl;
165 }
166};
167
169
170#endif
RealD axpy_norm(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
RealD norm2(const Lattice< vobj > &arg)
void precisionChange(Lattice< VobjOut > &out, const Lattice< VobjIn > &in, const precisionChangeWorkspace &workspace)
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
double RealD
Definition Simd.h:61
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
void operator()(const FieldD &src_d_in, FieldD &sol_d)
LinearOperatorBase< FieldF > & Linop_f
LinearOperatorBase< FieldD > & Linop_d
MixedPrecisionConjugateGradient(RealD tol, Integer maxinnerit, Integer maxouterit, GridBase *_sp_grid, LinearOperatorBase< FieldF > &_Linop_f, LinearOperatorBase< FieldD > &_Linop_d)
void useGuesser(LinearFunction< FieldF > &g)
Definition Simd.h:194