Grid 0.7.0
BiCGSTABMixedPrec.h
Go to the documentation of this file.
/*************************************************************************************

Grid physics library, www.github.com/paboyle/Grid

Source file: ./lib/algorithms/iterative/BiCGSTABMixedPrec.h

Copyright (C) 2015

Author: Christopher Kelly <ckelly@phys.columbia.edu>
Author: David Murphy <djmurphy@mit.edu>

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.

See the full license in the file "LICENSE" in the top level distribution directory
*************************************************************************************/
/* END LEGAL */

30#ifndef GRID_BICGSTAB_MIXED_PREC_H
31#define GRID_BICGSTAB_MIXED_PREC_H
32
34
35// Mixed precision restarted defect correction BiCGSTAB
36template<class FieldD, class FieldF, typename std::enable_if< getPrecision<FieldD>::value == 2, int>::type = 0, typename std::enable_if< getPrecision<FieldF>::value == 1, int>::type = 0>
38{
39 public:
40 using LinearFunction<FieldD>::operator();
42 RealD InnerTolerance; // Initial tolerance for inner CG. Defaults to Tolerance but can be changed
45 GridBase* SinglePrecGrid; // Grid for single-precision fields
46 RealD OuterLoopNormMult; // Stop the outer loop and move to a final double prec solve when the residual is OuterLoopNormMult * Tolerance
49
50 Integer TotalInnerIterations; //Number of inner CG iterations
51 Integer TotalOuterIterations; //Number of restarts
52 Integer TotalFinalStepIterations; //Number of CG iterations in final patch-up step
53
54 //Option to speed up *inner single precision* solves using a LinearFunction that produces a guess
56
57 MixedPrecisionBiCGSTAB(RealD tol, Integer maxinnerit, Integer maxouterit, GridBase* _sp_grid,
59 Linop_f(_Linop_f), Linop_d(_Linop_d), Tolerance(tol), InnerTolerance(tol), MaxInnerIterations(maxinnerit),
60 MaxOuterIterations(maxouterit), SinglePrecGrid(_sp_grid), OuterLoopNormMult(100.), guesser(NULL) {};
61
63 guesser = &g;
64 }
65
66 void operator() (const FieldD& src_d_in, FieldD& sol_d)
67 {
69
70 GridStopWatch TotalTimer;
71 TotalTimer.Start();
72
73 int cb = src_d_in.Checkerboard();
74 sol_d.Checkerboard() = cb;
75
76 RealD src_norm = norm2(src_d_in);
77 RealD stop = src_norm * Tolerance*Tolerance;
78
79 GridBase* DoublePrecGrid = src_d_in.Grid();
80 FieldD tmp_d(DoublePrecGrid);
81 tmp_d.Checkerboard() = cb;
82
83 FieldD tmp2_d(DoublePrecGrid);
84 tmp2_d.Checkerboard() = cb;
85
86 FieldD src_d(DoublePrecGrid);
87 src_d = src_d_in; //source for next inner iteration, computed from residual during operation
88
89 RealD inner_tol = InnerTolerance;
90
91 FieldF src_f(SinglePrecGrid);
92 src_f.Checkerboard() = cb;
93
94 FieldF sol_f(SinglePrecGrid);
95 sol_f.Checkerboard() = cb;
96
97 BiCGSTAB<FieldF> CG_f(inner_tol, MaxInnerIterations);
98 CG_f.ErrorOnNoConverge = false;
99
100 GridStopWatch InnerCGtimer;
101
102 GridStopWatch PrecChangeTimer;
103
104 Integer &outer_iter = TotalOuterIterations; //so it will be equal to the final iteration count
105
106 for(outer_iter = 0; outer_iter < MaxOuterIterations; outer_iter++)
107 {
108 // Compute double precision rsd and also new RHS vector.
109 Linop_d.Op(sol_d, tmp_d);
110 RealD norm = axpy_norm(src_d, -1., tmp_d, src_d_in); //src_d is residual vector
111
112 std::cout << GridLogMessage << "MixedPrecisionBiCGSTAB: Outer iteration " << outer_iter << " residual " << norm << " target " << stop << std::endl;
113
114 if(norm < OuterLoopNormMult * stop){
115 std::cout << GridLogMessage << "MixedPrecisionBiCGSTAB: Outer iteration converged on iteration " << outer_iter << std::endl;
116 break;
117 }
118 while(norm * inner_tol * inner_tol < stop){ inner_tol *= 2; } // inner_tol = sqrt(stop/norm) ??
119
120 PrecChangeTimer.Start();
121 precisionChange(src_f, src_d);
122 PrecChangeTimer.Stop();
123
124 sol_f = Zero();
125
126 //Optionally improve inner solver guess (eg using known eigenvectors)
127 if(guesser != NULL){ (*guesser)(src_f, sol_f); }
128
129 //Inner CG
130 CG_f.Tolerance = inner_tol;
131 InnerCGtimer.Start();
132 CG_f(Linop_f, src_f, sol_f);
133 InnerCGtimer.Stop();
135
136 //Convert sol back to double and add to double prec solution
137 PrecChangeTimer.Start();
138 precisionChange(tmp_d, sol_f);
139 PrecChangeTimer.Stop();
140
141 axpy(sol_d, 1.0, tmp_d, sol_d);
142 }
143
144 //Final trial CG
145 std::cout << GridLogMessage << "MixedPrecisionBiCGSTAB: Starting final patch-up double-precision solve" << std::endl;
146
148 CG_d(Linop_d, src_d_in, sol_d);
150
151 TotalTimer.Stop();
152 std::cout << GridLogMessage << "MixedPrecisionBiCGSTAB: Inner CG iterations " << TotalInnerIterations << " Restarts " << TotalOuterIterations << " Final CG iterations " << TotalFinalStepIterations << std::endl;
153 std::cout << GridLogMessage << "MixedPrecisionBiCGSTAB: Total time " << TotalTimer.Elapsed() << " Precision change " << PrecChangeTimer.Elapsed() << " Inner CG total " << InnerCGtimer.Elapsed() << std::endl;
154 }
155};
156
158
159#endif
RealD axpy_norm(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
void axpy(Lattice< vobj > &ret, sobj a, const Lattice< vobj > &x, const Lattice< vobj > &y)
RealD norm2(const Lattice< vobj > &arg)
void precisionChange(Lattice< VobjOut > &out, const Lattice< VobjIn > &in, const precisionChangeWorkspace &workspace)
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
double RealD
Definition Simd.h:61
bool ErrorOnNoConverge
Definition BiCGSTAB.h:50
RealD Tolerance
Definition BiCGSTAB.h:52
Integer IterationsToComplete
Definition BiCGSTAB.h:54
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
LinearFunction< FieldF > * guesser
void operator()(const FieldD &src_d_in, FieldD &sol_d)
LinearOperatorBase< FieldD > & Linop_d
LinearOperatorBase< FieldF > & Linop_f
void useGuesser(LinearFunction< FieldF > &g)
MixedPrecisionBiCGSTAB(RealD tol, Integer maxinnerit, Integer maxouterit, GridBase *_sp_grid, LinearOperatorBase< FieldF > &_Linop_f, LinearOperatorBase< FieldD > &_Linop_d)
Definition Simd.h:194