Grid 0.7.0
FFT.h
Go to the documentation of this file.
1/*************************************************************************************
2
3 Grid physics library, www.github.com/paboyle/Grid
4
5 Source file: ./lib/FFT.h
6
7 Copyright (C) 2015
8
9Author: Peter Boyle <paboyle@ph.ed.ac.uk>
10
11 This program is free software; you can redistribute it and/or modify
12 it under the terms of the GNU General Public License as published by
13 the Free Software Foundation; either version 2 of the License, or
14 (at your option) any later version.
15
16 This program is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU General Public License for more details.
20
21 You should have received a copy of the GNU General Public License along
22 with this program; if not, write to the Free Software Foundation, Inc.,
23 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25 See the full license in the file "LICENSE" in the top level distribution directory
26*************************************************************************************/
27/* END LEGAL */
28#ifndef _GRID_FFT_H_
29#define _GRID_FFT_H_
30
31#ifdef HAVE_FFTW
32#if defined(USE_MKL) || defined(GRID_SYCL)
33#include <fftw/fftw3.h>
34#else
35#include <fftw3.h>
36#endif
37#endif
38
40
/// Precision-dispatch trait for the FFTW C API.
/// The primary template is deliberately empty; the usable specializations
/// (for ComplexD and ComplexF) are only declared when HAVE_FFTW is defined.
template <typename scalar>
struct FFTW {};
42
43#ifdef HAVE_FFTW
44template<> struct FFTW<ComplexD> {
45public:
46
47 typedef fftw_complex FFTW_scalar;
48 typedef fftw_plan FFTW_plan;
49
50 static FFTW_plan fftw_plan_many_dft(int rank, const int *n,int howmany,
51 FFTW_scalar *in, const int *inembed,
52 int istride, int idist,
53 FFTW_scalar *out, const int *onembed,
54 int ostride, int odist,
55 int sign, unsigned flags) {
56 return ::fftw_plan_many_dft(rank,n,howmany,in,inembed,istride,idist,out,onembed,ostride,odist,sign,flags);
57 }
58
59 static void fftw_flops(const FFTW_plan p,double *add, double *mul, double *fmas){
60 ::fftw_flops(p,add,mul,fmas);
61 }
62
63 inline static void fftw_execute_dft(const FFTW_plan p,FFTW_scalar *in,FFTW_scalar *out) {
64 ::fftw_execute_dft(p,in,out);
65 }
66 inline static void fftw_destroy_plan(const FFTW_plan p) {
67 ::fftw_destroy_plan(p);
68 }
69};
70
71template<> struct FFTW<ComplexF> {
72public:
73
74 typedef fftwf_complex FFTW_scalar;
75 typedef fftwf_plan FFTW_plan;
76
77 static FFTW_plan fftw_plan_many_dft(int rank, const int *n,int howmany,
78 FFTW_scalar *in, const int *inembed,
79 int istride, int idist,
80 FFTW_scalar *out, const int *onembed,
81 int ostride, int odist,
82 int sign, unsigned flags) {
83 return ::fftwf_plan_many_dft(rank,n,howmany,in,inembed,istride,idist,out,onembed,ostride,odist,sign,flags);
84 }
85
86 static void fftw_flops(const FFTW_plan p,double *add, double *mul, double *fmas){
87 ::fftwf_flops(p,add,mul,fmas);
88 }
89
90 inline static void fftw_execute_dft(const FFTW_plan p,FFTW_scalar *in,FFTW_scalar *out) {
91 ::fftwf_execute_dft(p,in,out);
92 }
93 inline static void fftw_destroy_plan(const FFTW_plan p) {
94 ::fftwf_destroy_plan(p);
95 }
96};
97
98#endif
99
// Fallback transform-direction constants for builds without <fftw3.h>, so that
// FFT::forward / FFT::backward remain defined. Values mirror FFTW's convention.
#if !defined(FFTW_FORWARD)
#define FFTW_FORWARD  (-1)
#define FFTW_BACKWARD (+1)
#endif
104
105class FFT {
106private:
107
110
111 int Nd;
112 double flops;
114 uint64_t usec;
115
119
120public:
121
122 static const int forward=FFTW_FORWARD;
123 static const int backward=FFTW_BACKWARD;
124
125 double Flops(void) {return flops;}
126 double MFlops(void) {return flops/usec;}
127 double USec(void) {return (double)usec;}
128
129 FFT ( GridCartesian * grid ) :
130 vgrid(grid),
131 Nd(grid->_ndimension),
132 dimensions(grid->_fdimensions),
133 processors(grid->_processors),
134 processor_coor(grid->_processor_coor)
135 {
136 flops=0;
137 usec =0;
138 Coordinate layout(Nd,1);
139 sgrid = new GridCartesian(dimensions,layout,processors,*grid);
140 };
141
142 ~FFT ( void) {
143 delete sgrid;
144 }
145
146 template<class vobj>
147 void FFT_dim_mask(Lattice<vobj> &result,const Lattice<vobj> &source,Coordinate mask,int sign){
148
149 conformable(result.Grid(),vgrid);
150 conformable(source.Grid(),vgrid);
151 Lattice<vobj> tmp(vgrid);
152 tmp = source;
153 for(int d=0;d<Nd;d++){
154 if( mask[d] ) {
155 FFT_dim(result,tmp,d,sign);
156 tmp=result;
157 }
158 }
159 }
160
161 template<class vobj>
162 void FFT_all_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int sign){
163 Coordinate mask(Nd,1);
164 FFT_dim_mask(result,source,mask,sign);
165 }
166
167
168 template<class vobj>
169 void FFT_dim(Lattice<vobj> &result,const Lattice<vobj> &source,int dim, int sign){
170#ifndef HAVE_FFTW
171 std::cerr << "FFTW is not compiled but is called"<<std::endl;
172 assert(0);
173#else
174 conformable(result.Grid(),vgrid);
175 conformable(source.Grid(),vgrid);
176
177 int L = vgrid->_ldimensions[dim];
178 int G = vgrid->_fdimensions[dim];
179
180 Coordinate layout(Nd,1);
181 Coordinate pencil_gd(vgrid->_fdimensions);
182
183 pencil_gd[dim] = G*processors[dim];
184
185 // Pencil global vol LxLxGxLxL per node
186 GridCartesian pencil_g(pencil_gd,layout,processors,*vgrid);
187
188 // Construct pencils
189 typedef typename vobj::scalar_object sobj;
190 typedef typename sobj::scalar_type scalar;
191
192 Lattice<sobj> pgbuf(&pencil_g);
193 autoView(pgbuf_v , pgbuf, CpuWrite);
194 //std::cout << "CPU view" << std::endl;
195
196 typedef typename FFTW<scalar>::FFTW_scalar FFTW_scalar;
197 typedef typename FFTW<scalar>::FFTW_plan FFTW_plan;
198
199 int Ncomp = sizeof(sobj)/sizeof(scalar);
200 int Nlow = 1;
201 for(int d=0;d<dim;d++){
202 Nlow*=vgrid->_ldimensions[d];
203 }
204
205 int rank = 1; /* 1d transforms */
206 int n[] = {G}; /* 1d transforms of length G */
207 int howmany = Ncomp;
208 int odist,idist,istride,ostride;
209 idist = odist = 1; /* Distance between consecutive FT's */
210 istride = ostride = Ncomp*Nlow; /* distance between two elements in the same FT */
211 int *inembed = n, *onembed = n;
212
213 scalar div;
214 if ( sign == backward ) div = 1.0/G;
215 else if ( sign == forward ) div = 1.0;
216 else assert(0);
217
218 //std::cout << GridLogPerformance<<"Making FFTW plan" << std::endl;
219 FFTW_plan p;
220 {
221 FFTW_scalar *in = (FFTW_scalar *)&pgbuf_v[0];
222 FFTW_scalar *out= (FFTW_scalar *)&pgbuf_v[0];
223 p = FFTW<scalar>::fftw_plan_many_dft(rank,n,howmany,
224 in,inembed,
225 istride,idist,
226 out,onembed,
227 ostride, odist,
228 sign,FFTW_ESTIMATE);
229 }
230
231 // Barrel shift and collect global pencil
232 //std::cout << GridLogPerformance<<"Making pencil" << std::endl;
233 Coordinate lcoor(Nd), gcoor(Nd);
234 result = source;
235 int pc = processor_coor[dim];
236 for(int p=0;p<processors[dim];p++) {
237 {
238 autoView(r_v,result,CpuRead);
239 autoView(p_v,pgbuf,CpuWrite);
240 thread_for(idx, sgrid->lSites(),{
241 Coordinate cbuf(Nd);
242 sobj s;
243 sgrid->LocalIndexToLocalCoor(idx,cbuf);
244 peekLocalSite(s,r_v,cbuf);
245 cbuf[dim]+=((pc+p) % processors[dim])*L;
246 pokeLocalSite(s,p_v,cbuf);
247 });
248 }
249 if (p != processors[dim] - 1) {
250 result = Cshift(result,dim,L);
251 }
252 }
253
254 //std::cout <<GridLogPerformance<< "Looping orthog" << std::endl;
255 // Loop over orthog coords
256 int NN=pencil_g.lSites();
257 GridStopWatch timer;
258 timer.Start();
259 thread_for( idx,NN,{
260 Coordinate cbuf(Nd);
261 pencil_g.LocalIndexToLocalCoor(idx, cbuf);
262 if ( cbuf[dim] == 0 ) { // restricts loop to plane at lcoor[dim]==0
263 FFTW_scalar *in = (FFTW_scalar *)&pgbuf_v[idx];
264 FFTW_scalar *out= (FFTW_scalar *)&pgbuf_v[idx];
266 }
267 });
268 timer.Stop();
269
270 // performance counting
271 double add,mul,fma;
272 FFTW<scalar>::fftw_flops(p,&add,&mul,&fma);
273 flops_call = add+mul+2.0*fma;
274 usec += timer.useconds();
275 flops+= flops_call*NN;
276
277 //std::cout <<GridLogPerformance<< "Writing back results " << std::endl;
278 // writing out result
279 {
280 autoView(pgbuf_v,pgbuf,CpuRead);
281 autoView(result_v,result,CpuWrite);
282 thread_for(idx,sgrid->lSites(),{
283 Coordinate clbuf(Nd), cgbuf(Nd);
284 sobj s;
285 sgrid->LocalIndexToLocalCoor(idx,clbuf);
286 cgbuf = clbuf;
287 cgbuf[dim] = clbuf[dim]+L*pc;
288 peekLocalSite(s,pgbuf_v,cgbuf);
289 pokeLocalSite(s,result_v,clbuf);
290 });
291 }
292 result = result*div;
293
294 //std::cout <<GridLogPerformance<< "Destroying plan " << std::endl;
295 // destroying plan
297#endif
298 }
299};
300
302
303#endif
AcceleratorVector< int, MaxDims > Coordinate
Definition Coordinate.h:95
auto Cshift(const Expression &expr, int dim, int shift) -> decltype(closure(expr))
Definition Cshift.h:55
#define FFTW_FORWARD
Definition FFT.h:101
#define FFTW_BACKWARD
Definition FFT.h:102
void add(Lattice< obj1 > &ret, const Lattice< obj2 > &lhs, const Lattice< obj3 > &rhs)
void conformable(const Lattice< obj1 > &lhs, const Lattice< obj2 > &rhs)
Lattice< obj > div(const Lattice< obj > &rhs_i, Integer y)
#define autoView(l_v, l, mode)
@ CpuRead
@ CpuWrite
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
std::complex< RealF > ComplexF
Definition Simd.h:78
std::complex< RealD > ComplexD
Definition Simd.h:79
#define thread_for(i, num,...)
Definition Threads.h:60
Coordinate processor_coor
Definition FFT.h:118
double flops
Definition FFT.h:112
void FFT_dim(Lattice< vobj > &result, const Lattice< vobj > &source, int dim, int sign)
Definition FFT.h:169
~FFT(void)
Definition FFT.h:142
Coordinate dimensions
Definition FFT.h:116
static const int backward
Definition FFT.h:123
GridCartesian * vgrid
Definition FFT.h:108
void FFT_dim_mask(Lattice< vobj > &result, const Lattice< vobj > &source, Coordinate mask, int sign)
Definition FFT.h:147
static const int forward
Definition FFT.h:122
double Flops(void)
Definition FFT.h:125
double USec(void)
Definition FFT.h:127
uint64_t usec
Definition FFT.h:114
GridCartesian * sgrid
Definition FFT.h:109
Coordinate processors
Definition FFT.h:117
int Nd
Definition FFT.h:111
FFT(GridCartesian *grid)
Definition FFT.h:129
double MFlops(void)
Definition FFT.h:126
void FFT_all_dim(Lattice< vobj > &result, const Lattice< vobj > &source, int sign)
Definition FFT.h:162
double flops_call
Definition FFT.h:113
void LocalIndexToLocalCoor(int lidx, Coordinate &lcoor)
int lSites(void) const
void Start(void)
Definition Timer.h:92
uint64_t useconds(void) const
Definition Timer.h:117
void Stop(void)
Definition Timer.h:99
GridBase * Grid(void) const
Definition FFT.h:41