71 std::vector<FieldD> srcs_d_in{src_d_in};
72 std::vector<FieldD> sols_d{sol_d};
74 (*this)(srcs_d_in,sols_d);
79 void operator() (
const std::vector<FieldD> &src_d_in, std::vector<FieldD> &sol_d){
80 assert(src_d_in.size() == sol_d.size());
81 int NBatch = src_d_in.size();
85 Integer TotalOuterIterations = 0;
86 std::vector<Integer> TotalInnerIterations(NBatch,0);
87 std::vector<Integer> TotalFinalStepIterations(NBatch,0);
95 int cb = src_d_in[0].Checkerboard();
97 std::vector<RealD> src_norm;
98 std::vector<RealD> norm;
99 std::vector<RealD> stop;
101 GridBase* DoublePrecGrid = src_d_in[0].Grid();
102 FieldD tmp_d(DoublePrecGrid);
103 tmp_d.Checkerboard() = cb;
105 FieldD tmp2_d(DoublePrecGrid);
106 tmp2_d.Checkerboard() = cb;
108 std::vector<FieldD> src_d;
109 std::vector<FieldF> src_f;
110 std::vector<FieldF> sol_f;
112 for (
int i=0; i<NBatch; i++) {
113 sol_d[i].Checkerboard() = cb;
115 src_norm.push_back(
norm2(src_d_in[i]));
119 src_d.push_back(src_d_in[i]);
122 src_f[i].Checkerboard() = cb;
125 sol_f[i].Checkerboard() = cb;
133 Integer &outer_iter = TotalOuterIterations;
137 std::cout <<
GridLogMessage <<
"Outer iteration " << outer_iter << std::endl;
139 bool allConverged =
true;
141 for (
int i=0; i<NBatch; i++) {
143 Linop_d.HermOp(sol_d[i], tmp_d);
144 norm[i] =
axpy_norm(src_d[i], -1., tmp_d, src_d_in[i]);
146 std::cout<<
GridLogMessage<<
"MixedPrecisionConjugateGradientBatched: Outer iteration " << outer_iter <<
" solve " << i <<
" residual "<< norm[i] <<
" target "<< stop[i] <<std::endl;
148 PrecChangeTimer.
Start();
150 PrecChangeTimer.
Stop();
155 allConverged =
false;
158 if (allConverged)
break;
161 RealD normMax = *std::max_element(std::begin(norm), std::end(norm));
162 RealD stopMax = *std::max_element(std::begin(stop), std::end(stop));
163 while( normMax * inner_tol * inner_tol < stopMax) inner_tol *= 2;
169 (*guesser)(src_f, sol_f);
172 for (
int i=0; i<NBatch; i++) {
174 InnerCGtimer.
Start();
175 CG_f(
Linop_f, src_f[i], sol_f[i]);
180 PrecChangeTimer.
Start();
182 PrecChangeTimer.
Stop();
184 axpy(sol_d[i], 1.0, tmp_d, sol_d[i]);
191 std::cout<<
GridLogMessage<<
"MixedPrecisionConjugateGradientBatched: Starting final patch-up double-precision solve"<<std::endl;
193 for (
int i=0; i<NBatch; i++) {
195 CG_d(
Linop_d, src_d_in[i], sol_d[i]);
202 for (
int i=0; i<NBatch; i++) {
203 std::cout<<
GridLogMessage<<
"MixedPrecisionConjugateGradientBatched: solve " << i <<
" Inner CG iterations " << TotalInnerIterations[i] <<
" Restarts " << TotalOuterIterations <<
" Final CG iterations " << TotalFinalStepIterations[i] << std::endl;
206 std::cout<<
GridLogMessage<<
"MixedPrecisionConjugateGradientBatched: Total time " << TotalTimer.
Elapsed() <<
" Precision change " << PrecChangeTimer.
Elapsed() <<
" Inner CG total " << InnerCGtimer.
Elapsed() << std::endl;