131 GRID_TRACE(
"ConjugateGradientMultiShiftMixedPrec");
132 GridBase *DoublePrecGrid = src_d.Grid();
140 int nshift =
shifts.order;
142 std::vector<RealD> &mass(
shifts.poles);
143 std::vector<RealD> &mresidual(
shifts.tolerances);
144 std::vector<RealD> alpha(nshift,1.0);
147 FieldD p_d(DoublePrecGrid);
148 std::vector<FieldD> ps_d(nshift, DoublePrecGrid);
150 FieldD tmp_d(DoublePrecGrid);
151 FieldD r_d(DoublePrecGrid);
152 FieldD mmp_d(DoublePrecGrid);
154 assert(psi_d.size()==nshift);
155 assert(mass.size()==nshift);
156 assert(mresidual.size()==nshift);
159 std::vector<RealD> bs(nshift);
160 std::vector<RealD> rsq(nshift);
161 std::vector<RealD> rsqf(nshift);
162 std::vector<std::array<RealD,2> > z(nshift);
163 std::vector<int> converged(nshift);
165 const int primary =0;
176 for(
int s=0;s<nshift;s++){
177 assert( mass[s]>= mass[primary] );
188 for(
int s=0;s<nshift;s++){
196 for(
int s=0;s<nshift;s++){
197 rsq[s] = cp * mresidual[s] * mresidual[s];
199 std::cout<<
GridLogMessage<<
"ConjugateGradientMultiShiftMixedPrec: shift "<< s <<
" target resid "<<rsq[s]<<std::endl;
209 Linop_f.HermOpAndNorm(p_f,mmp_f,d,qq);
212 tmp_d = tmp_d - mmp_d;
213 std::cout <<
" Testing operators match "<<
norm2(mmp_d)<<
" f "<<
norm2(mmp_f)<<
" diff "<<
norm2(tmp_d)<<std::endl;
214 assert(
norm2(tmp_d)< 1.0);
216 axpy(mmp_d,mass[0],p_d,mmp_d);
227 for(
int s=1;s<nshift;s++){
229 z[s][iz] = 1.0/( 1.0 - b*(mass[s]-mass[0]));
237 for(
int s=0;s<nshift;s++) {
238 axpby(psi_d[s],0.,-bs[s]*alpha[s],src_d,src_d);
244 GridStopWatch AXPYTimer, ShiftTimer, QRTimer, MatrixTimer, SolverTimer, PrecChangeTimer, CleanupTimer;
257 for(
int s=0;s<nshift;s++){
258 if ( ! converged[s] ) {
260 axpy(ps_d[s],a,ps_d[s],r_d);
262 RealD as =a *z[s][iz]*bs[s] /(z[s][1-iz]*b);
263 axpby(ps_d[s],z[s][iz],as,r_d,ps_d[s]);
269 PrecChangeTimer.
Start();
271 PrecChangeTimer.
Stop();
278 PrecChangeTimer.
Start();
280 PrecChangeTimer.
Stop();
284 axpy(mmp_d,mass[0],p_d,mmp_d);
296 for(
int s=1;s<nshift;s++){
298 RealD z0 = z[s][1-iz];
301 / (b*a*(z1-z0) + z1*bp*(1- (mass[s]-mass[0])*b));
302 bs[s] = b*z[s][iz]/z0;
309 for(
int s=0;s<nshift;s++){
311 if( (!converged[s]) ) {
312 axpy(psi_d[ss],-bs[s]*alpha[s],ps_d[s],psi_d[ss]);
325 Linop_d.
HermOp(psi_d[0],mmp_d);
329 axpy(mmp_d,mass[0],psi_d[0],mmp_d);
334 std::cout<<
GridLogMessage<<
"ConjugateGradientMultiShiftMixedPrec k="<<k<<
", replaced |r|^2 = "<<c_old <<
" with |r|^2 = "<<c<<std::endl;
338 int all_converged = 1;
339 for(
int s=0;s<nshift;s++){
341 if ( (!converged[s]) ){
344 RealD css = c * z[s][iz]* z[s][iz];
347 if ( ! converged[s] )
348 std::cout<<
GridLogMessage<<
"ConjugateGradientMultiShiftMixedPrec k="<<k<<
" Shift "<<s<<
" has converged"<<std::endl;
361 if ( all_converged ){
362 std::cout<<
GridLogMessage<<
"ConjugateGradientMultiShiftMixedPrec: All shifts have converged iteration "<<k<<std::endl;
363 std::cout<<
GridLogMessage<<
"ConjugateGradientMultiShiftMixedPrec: Checking solutions"<<std::endl;
365 std::cout<<
GridLogMessage<<
"ConjugateGradientMultiShiftMixedPrec: Not all shifts have converged iteration "<<k<<std::endl;
369 for(
int s=0; s < nshift; s++) {
371 axpy(tmp_d,mass[s],psi_d[s],mmp_d);
372 axpy(r_d,-alpha[s],src_d,tmp_d);
376 std::cout<<
GridLogMessage<<
"ConjugateGradientMultiShiftMixedPrec: shift["<<s<<
"] true residual "<<
TrueResidualShift[s] <<
" target " << mresidual[s] << std::endl;
380 CleanupTimer.
Start();
381 std::cout<<
GridLogMessage<<
"ConjugateGradientMultiShiftMixedPrec: performing cleanup step for shift " << s << std::endl;
395 std::cout <<
GridLogMessage <<
"ConjugateGradientMultiShiftMixedPrec: Time Breakdown for body"<<std::endl;
410 std::cout<<
GridLogMessage<<
"CG multi shift did not converge"<<std::endl;