Grid 0.7.0
FourierAcceleratedPV.h
Go to the documentation of this file.
1
2 /*************************************************************************************
3
4 Grid physics library, www.github.com/paboyle/Grid
5
6 Source file: ./lib/qcd/action/fermion/FourierAcceleratedPV.h
7
8 Copyright (C) 2015
9
10Author: Christoph Lehner (lifted with permission by Peter Boyle, brought back to Grid)
11Author: Peter Boyle <paboyle@ph.ed.ac.uk>
12
13 This program is free software; you can redistribute it and/or modify
14 it under the terms of the GNU General Public License as published by
15 the Free Software Foundation; either version 2 of the License, or
16 (at your option) any later version.
17
18 This program is distributed in the hope that it will be useful,
19 but WITHOUT ANY WARRANTY; without even the implied warranty of
20 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
21 GNU General Public License for more details.
22
23 You should have received a copy of the GNU General Public License along
24 with this program; if not, write to the Free Software Foundation, Inc.,
25 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
26
27 See the full license in the file "LICENSE" in the top level distribution directory
28 *************************************************************************************/
29 /* END LEGAL */
30#pragma once
31
33
34 template<typename M>
35 void get_real_const_bc(M& m, RealD& _b, RealD& _c) {
36 ComplexD b,c;
37 b=m.bs[0];
38 c=m.cs[0];
39 std::cout << GridLogMessage << "b=" << b << ", c=" << c << std::endl;
40 for (size_t i=1;i<m.bs.size();i++) {
41 assert(m.bs[i] == b);
42 assert(m.cs[i] == c);
43 }
44 assert(b.imag() == 0.0);
45 assert(c.imag() == 0.0);
46 _b = b.real();
47 _c = c.real();
48 }
49
50
51template<typename Vi, typename M, typename G>
53 public:
54
57 G& Umu;
61
// Construct the Fourier-accelerated PV helper.
//
//   _dwfPV      operator used to initialise the `dwfPV` member; its FermionGrid()
//               supplies the extent checked below
//   _Umu        gauge field, bound to the reference member `Umu`
//   _cg         conjugate-gradient solver used to initialise the `cg` member
//   _group_in_s value stored in `group_in_s` (default 2); pvInv() iterates over
//               Ls/2/group_in_s groups of group_in_s slice pairs
62 FourierAcceleratedPV(M& _dwfPV, G& _Umu, ConjugateGradient<Vi> &_cg, int _group_in_s = 2)
63 : dwfPV(_dwfPV), Umu(_Umu), cg(_cg), group_in_s(_group_in_s)
64 {
// Extent 0 of the fermion grid must be a multiple of 2*group_in_s so that it
// splits into whole groups of slice pairs.
65 assert( dwfPV.FermionGrid()->_fdimensions[0] % (2*group_in_s) == 0);
// NOTE(review): listing lines 66-67 are absent from this dump; presumably they
// create the `grid5D` / `gridRB5D` grids that pvInv() uses -- confirm against
// the real header.
68 }
69
70 void rotatePV(const Vi& _src, Vi& dst, bool forward) const {
71
72 GridStopWatch gsw1, gsw2;
73
74 typedef typename Vi::scalar_type Coeff_t;
75 int Ls = dst.Grid()->_fdimensions[0];
76
77 Vi _tmp(dst.Grid());
78 double phase = M_PI / (double)Ls;
79 Coeff_t bzero(0.0,0.0);
80
81 FFT theFFT((GridCartesian*)dst.Grid());
82
83 if (!forward) {
84 gsw1.Start();
85 for (int s=0;s<Ls;s++) {
86 Coeff_t a(::cos(phase*s),-::sin(phase*s));
87 axpby_ssp(_tmp,a,_src,bzero,_src,s,s);
88 }
89 gsw1.Stop();
90
91 gsw2.Start();
92 theFFT.FFT_dim(dst,_tmp,0,FFT::forward);
93 gsw2.Stop();
94
95 } else {
96
97 gsw2.Start();
98 theFFT.FFT_dim(_tmp,_src,0,FFT::backward);
99 gsw2.Stop();
100
101 gsw1.Start();
102 for (int s=0;s<Ls;s++) {
103 Coeff_t a(::cos(phase*s),::sin(phase*s));
104 axpby_ssp(dst,a,_tmp,bzero,_tmp,s,s);
105 }
106 gsw1.Stop();
107 }
108
109 std::cout << GridLogMessage << "Timing rotatePV: " << gsw1.Elapsed() << ", " << gsw2.Elapsed() << std::endl;
110
111 }
112
// Compute `_dst` from `_src` in three stages:
//   1. rotatePV(_src, _src_diag, false)
//   2. for each group of `group_in_s` slice pairs (s, Ls-1-s): set per-slice
//      mass/mu values on `tm`, pack the pairs into a grid5D field, solve with
//      `sol` (built on `cg`) from a zero initial guess, and unpack the result
//      with a (pA -/+ pB*G5)/(pA^2 - pB^2) multiplication into `_dst_diag`
//   3. rotatePV(_dst_diag, _dst, true)
// Timings for the whole call, the setup and the s-loop are logged.
//
// NOTE(review): this dump is missing listing lines 124 and 150. `b`, `c` are
// declared on 123 and first used on 171, and `tm` is used on 182/200 but never
// declared here; presumably 124 fills b,c (e.g. via get_real_const_bc) and 150
// begins the declaration of `tm` -- confirm against the real header.
113 void pvInv(const Vi& _src, Vi& _dst) const {
114
115 std::cout << GridLogMessage << "Fourier-Accelerated Outer Pauli Villars"<<std::endl;
116
117 typedef typename Vi::scalar_type Coeff_t;
// Ls: extent 0 of the destination field's grid.
118 int Ls = _dst.Grid()->_fdimensions[0];
119
// gswT times the entire call.
120 GridStopWatch gswT;
121 gswT.Start();
122
123 RealD b,c;
// (listing line 124 missing here -- see NOTE(review) at the top)
125 RealD M5 = dwfPV.M5;
126
127 // U(true) Rightinv TMinv U(false) = Minv
128
// Work fields: *_diag live on the grid of _dst, *_slice on the gauge grid,
// *_slices on grid5D (which holds one group of packed slice pairs).
129 Vi _src_diag(_dst.Grid());
130 Vi _src_diag_slice(dwfPV.GaugeGrid());
131 Vi _dst_diag_slice(dwfPV.GaugeGrid());
132 Vi _src_diag_slices(grid5D);
133 Vi _dst_diag_slices(grid5D);
134 Vi _dst_diag(_dst.Grid());
135
136 rotatePV(_src,_src_diag,false);
137
138 // now do TM solves
139 Gamma G5(Gamma::Algebra::Gamma5);
140
// gswA times the operator/solver setup, gswB the loop over slice groups.
141 GridStopWatch gswA, gswB;
142
143 gswA.Start();
144
145 typedef typename M::Impl_t Impl;
146 //WilsonTMFermion<Impl> tm(x.Umu,*x.UGridF,*x.UrbGridF,0.0,0.0,solver_outer.parent.par.wparams_f);
// One mass and one mu entry per slice of grid5D, initially zero.
147 std::vector<RealD> vmass(grid5D->_fdimensions[0],0.0);
148 std::vector<RealD> vmu(grid5D->_fdimensions[0],0.0);
149
// (listing line 150 missing here; the three lines below are the trailing
// constructor arguments of the `tm` object used later)
151 *(GridCartesian*)dwfPV.GaugeGrid(),
152 *(GridRedBlackCartesian*)dwfPV.GaugeRedBlackGrid(),
153 vmass,vmu);
154
155 //SchurRedBlackDiagTwoSolve<Vi> sol(cg);
156 SchurRedBlackDiagMooeeSolve<Vi> sol(cg); // same performance as DiagTwo
157 gswA.Stop();
158
159 gswB.Start();
160
// Each iteration handles group_in_s pairs (s, Ls-1-s) in a single solve.
161 for (int sgroup=0;sgroup<Ls/2/group_in_s;sgroup++) {
162
// Pass 1: compute the mass/mu values for every pair in this group.
163 for (int sidx=0;sidx<group_in_s;sidx++) {
164
165 int s = sgroup*group_in_s + sidx;
166 // int sprime = Ls-s-1;
167
// phase = pi*(2s+1)/Ls
168 RealD phase = M_PI / (RealD)Ls * (2.0 * s + 1.0);
169 RealD cosp = ::cos(phase);
170 RealD sinp = ::sin(phase);
171 RealD denom = b*b + c*c + 2.0*b*c*cosp;
172 RealD mass = -(b*b*M5 + c*(1.0 - cosp + c*M5) + b*(-1.0 + cosp + 2.0*c*cosp*M5))/denom;
173 RealD mu = (b+c)*sinp/denom;
174
// Packed slot 2*sidx holds slice s, slot 2*sidx+1 holds its partner Ls-1-s;
// both share the mass, and the partner gets mu with the opposite sign.
175 vmass[2*sidx + 0] = mass;
176 vmass[2*sidx + 1] = mass;
177 vmu[2*sidx + 0] = mu;
178 vmu[2*sidx + 1] = -mu;
179
180 }
181
182 tm.update(vmass,vmu);
183
// Pass 2: copy slices s and sprime of the rotated source into the packed field.
184 for (int sidx=0;sidx<group_in_s;sidx++) {
185
186 int s = sgroup*group_in_s + sidx;
187 int sprime = Ls-s-1;
188
189 ExtractSlice(_src_diag_slice,_src_diag,s,0);
190 InsertSlice(_src_diag_slice,_src_diag_slices,2*sidx + 0,0);
191
192 ExtractSlice(_src_diag_slice,_src_diag,sprime,0);
193 InsertSlice(_src_diag_slice,_src_diag_slices,2*sidx + 1,0);
194
195 }
196
197 GridStopWatch gsw;
198 gsw.Start();
199 _dst_diag_slices = Zero(); // zero guess
200 sol(tm,_src_diag_slices,_dst_diag_slices);
201 gsw.Stop();
// The second number logged is gswA (the setup time), not a per-group value.
202 std::cout << GridLogMessage << "Solve[sgroup=" << sgroup << "] completed in " << gsw.Elapsed() << ", " << gswA.Elapsed() << std::endl;
203
// Pass 3: unpack the solution, applying the (pA -/+ pB*G5)/(pA^2 - pB^2) factor.
204 for (int sidx=0;sidx<group_in_s;sidx++) {
205
206 int s = sgroup*group_in_s + sidx;
207 int sprime = Ls-s-1;
208
209 RealD phase = M_PI / (RealD)Ls * (2.0 * s + 1.0);
210 RealD cosp = ::cos(phase);
211 RealD sinp = ::sin(phase);
212
213 // now rotate with inverse of
// pA = b + c*cos(phase), pB = -i*c*sin(phase)
214 Coeff_t pA = b + c*cosp;
215 Coeff_t pB = - Coeff_t(0.0,1.0)*Coeff_t(c*sinp);
216 Coeff_t pABden = pA*pA - pB*pB;
217 // (pA + pB * G5) * (pA - pB*G5) = (pA^2 - pB^2)
218
// Slice s: multiply by (pA - pB*G5)/(pA^2 - pB^2).
219 ExtractSlice(_dst_diag_slice,_dst_diag_slices,2*sidx + 0,0);
220 _dst_diag_slice = (pA/pABden) * _dst_diag_slice - (pB/pABden) * (G5 * _dst_diag_slice);
221 InsertSlice(_dst_diag_slice,_dst_diag,s,0);
222
// Partner slice sprime: same factor with the sign of pB flipped.
223 ExtractSlice(_dst_diag_slice,_dst_diag_slices,2*sidx + 1,0);
224 _dst_diag_slice = (pA/pABden) * _dst_diag_slice + (pB/pABden) * (G5 * _dst_diag_slice);
225 InsertSlice(_dst_diag_slice,_dst_diag,sprime,0);
226 }
227 }
228 gswB.Stop();
229
230 rotatePV(_dst_diag,_dst,true);
231
232 gswT.Stop();
233 std::cout << GridLogMessage << "PV completed in " << gswT.Elapsed() << " (Setup: " << gswA.Elapsed() << ", s-loop: " << gswB.Elapsed() << ")" << std::endl;
234 }
235
236};
238
void get_real_const_bc(M &m, RealD &_b, RealD &_c)
accelerator_inline Grid_simd< S, V > cos(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > sin(const Grid_simd< S, V > &r)
void InsertSlice(const Lattice< vobj > &lowDim, Lattice< vobj > &higherDim, int slice, int orthog)
void ExtractSlice(Lattice< vobj > &lowDim, const Lattice< vobj > &higherDim, int slice, int orthog)
void axpby_ssp(Lattice< vobj > &z, Coeff a, const Lattice< vobj > &x, Coeff b, const Lattice< vobj > &y, int s, int sp)
Definition LinalgUtils.h:59
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
#define M_PI
Definition Zolotarev.cc:41
Definition FFT.h:105
void FFT_dim(Lattice< vobj > &result, const Lattice< vobj > &source, int dim, int sign)
Definition FFT.h:169
static const int backward
Definition FFT.h:123
static const int forward
Definition FFT.h:122
FourierAcceleratedPV(M &_dwfPV, G &_Umu, ConjugateGradient< Vi > &_cg, int _group_in_s=2)
GridRedBlackCartesian * gridRB5D
ConjugateGradient< Vi > & cg
void rotatePV(const Vi &_src, Vi &dst, bool forward) const
void pvInv(const Vi &_src, Vi &_dst) const
Definition Gamma.h:10
void Start(void)
Definition Timer.h:92
GridTime Elapsed(void) const
Definition Timer.h:113
void Stop(void)
Definition Timer.h:99
static GridCartesian * makeFiveDimGrid(int Ls, const GridCartesian *FourDimGrid)
static GridRedBlackCartesian * makeFiveDimRedBlackGrid(int Ls, const GridCartesian *FourDimGrid)
void update(const std::vector< RealD > &_mass, const std::vector< RealD > &_mu)
Definition Simd.h:194