Grid 0.7.0
Lattice_ET.h
Go to the documentation of this file.
1/*************************************************************************************
2
3Grid physics library, www.github.com/paboyle/Grid
4
5Source file: ./lib/lattice/Lattice_ET.h
6
7Copyright (C) 2015
8
9Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
10Author: Peter Boyle <paboyle@ph.ed.ac.uk>
11Author: neo <cossu@post.kek.jp>
12Author: Christoph Lehner <christoph@lhnr.de>
13
14This program is free software; you can redistribute it and/or modify
15it under the terms of the GNU General Public License as published by
16the Free Software Foundation; either version 2 of the License, or
17(at your option) any later version.
18
19This program is distributed in the hope that it will be useful,
20but WITHOUT ANY WARRANTY; without even the implied warranty of
21MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
22GNU General Public License for more details.
23
24You should have received a copy of the GNU General Public License along
25with this program; if not, write to the Free Software Foundation, Inc.,
2651 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
27
28See the full license in the file "LICENSE" in the top level distribution
29directory
30*************************************************************************************/
31 /* END LEGAL */
32#ifndef GRID_LATTICE_ET_H
33#define GRID_LATTICE_ET_H
34
35#include <iostream>
36#include <tuple>
37#include <typeinfo>
38#include <vector>
39
41
43// Predicated where support
45#ifdef GRID_SIMT
46// drop to scalar in SIMT; cleaner in fact
47template <class iobj, class vobj, class robj>
48accelerator_inline vobj predicatedWhere(const iobj &predicate,
49 const vobj &iftrue,
50 const robj &iffalse)
51{
52 Integer mask = TensorRemove(predicate);
53 typename std::remove_const<vobj>::type ret= iffalse;
54 if (mask) ret=iftrue;
55 return ret;
56}
57#else
58template <class iobj, class vobj, class robj>
59accelerator_inline vobj predicatedWhere(const iobj &predicate,
60 const vobj &iftrue,
61 const robj &iffalse)
62{
63 typename std::remove_const<vobj>::type ret;
64
65 typedef typename vobj::scalar_object scalar_object;
66 // typedef typename vobj::scalar_type scalar_type;
67 typedef typename vobj::vector_type vector_type;
68
69 const int Nsimd = vobj::vector_type::Nsimd();
70
71 ExtractBuffer<Integer> mask(Nsimd);
72 ExtractBuffer<scalar_object> truevals(Nsimd);
73 ExtractBuffer<scalar_object> falsevals(Nsimd);
74
75 extract(iftrue, truevals);
76 extract(iffalse, falsevals);
78
79 for (int s = 0; s < Nsimd; s++) {
80 if (mask[s]) falsevals[s] = truevals[s];
81 }
82
83 merge(ret, falsevals);
84 return ret;
85}
86#endif
87
89//Specialization of getVectorType for lattices
91template<typename T>
// NOTE(review): source lines 92-93 are missing from this listing, so this
// fragment does not compile as shown. The cross-reference entry for line 93
// ("Lattice< T >::vector_object type") suggests they read
//   struct getVectorType<Lattice<T> >{
//     typedef typename Lattice<T>::vector_object type;
// i.e. the vector type of a Lattice<T> is its vector_object. Confirm against
// the upstream header before restoring.
94};
95
97//-- recursive evaluation of expressions; --
98// handle leaves of syntax tree
100template<class sobj,
101 typename std::enable_if<!is_lattice<sobj>::value&&!is_lattice_expr<sobj>::value,sobj>::type * = nullptr>
103sobj eval(const uint64_t ss, const sobj &arg)
104{
105 return arg;
106}
107template <class lobj> accelerator_inline
108auto eval(const uint64_t ss, const LatticeView<lobj> &arg) -> decltype(arg(ss))
109{
110 return arg(ss);
111}
112
114//-- recursive evaluation of expressions; --
115// whole vector return, used only for expression return type inference
117template<class sobj> accelerator_inline
118sobj vecEval(const uint64_t ss, const sobj &arg)
119{
120 return arg;
121}
122template <class lobj> accelerator_inline
123const lobj & vecEval(const uint64_t ss, const LatticeView<lobj> &arg)
124{
125 return arg[ss];
126}
127
129// handle nodes in syntax tree- eval one operand
130// vecEval needed (but never called as all expressions offloaded) to infer the return type
131// in SIMT contexts of closure.
133template <typename Op, typename T1> accelerator_inline
134auto vecEval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
135 -> decltype(expr.op.func( vecEval(ss, expr.arg1)))
136{
137 return expr.op.func( vecEval(ss, expr.arg1) );
138}
139// vecEval two operands
140template <typename Op, typename T1, typename T2> accelerator_inline
141auto vecEval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
142 -> decltype(expr.op.func( vecEval(ss,expr.arg1),vecEval(ss,expr.arg2)))
143{
144 return expr.op.func( vecEval(ss,expr.arg1), vecEval(ss,expr.arg2) );
145}
146// vecEval three operands
147template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
148auto vecEval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
149 -> decltype(expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3)))
150{
151 return expr.op.func(vecEval(ss, expr.arg1), vecEval(ss, expr.arg2), vecEval(ss, expr.arg3));
152}
153
155// handle nodes in syntax tree- eval one operand coalesced
157template <typename Op, typename T1> accelerator_inline
158auto eval(const uint64_t ss, const LatticeUnaryExpression<Op, T1> &expr)
159 -> decltype(expr.op.func( eval(ss, expr.arg1)))
160{
161 return expr.op.func( eval(ss, expr.arg1) );
162}
163// eval two operands
164template <typename Op, typename T1, typename T2> accelerator_inline
165auto eval(const uint64_t ss, const LatticeBinaryExpression<Op, T1, T2> &expr)
166 -> decltype(expr.op.func( eval(ss,expr.arg1),eval(ss,expr.arg2)))
167{
168 return expr.op.func( eval(ss,expr.arg1), eval(ss,expr.arg2) );
169}
170// eval three operands
171template <typename Op, typename T1, typename T2, typename T3> accelerator_inline
172auto eval(const uint64_t ss, const LatticeTrinaryExpression<Op, T1, T2, T3> &expr)
173 -> decltype(expr.op.func(eval(ss, expr.arg1),
174 eval(ss, expr.arg2),
175 eval(ss, expr.arg3)))
176{
177#ifdef GRID_SIMT
178 // Handles Nsimd (vInteger) != Nsimd(ComplexD)
179 typedef decltype(vecEval(ss, expr.arg2)) rvobj;
180 typedef typename std::remove_reference<rvobj>::type vobj;
181
182 const int Nsimd = vobj::vector_type::Nsimd();
183
184 auto vpred = vecEval(ss,expr.arg1);
185
186 ExtractBuffer<Integer> mask(Nsimd);
188
189 int s = acceleratorSIMTlane(Nsimd);
190 return expr.op.func(mask[s],
191 eval(ss, expr.arg2),
192 eval(ss, expr.arg3));
193#else
194 return expr.op.func(eval(ss, expr.arg1),
195 eval(ss, expr.arg2),
196 eval(ss, expr.arg3));
197#endif
198}
199
201// Obtain the grid from an expression, ensuring conformable. This must follow a
202// tree recursion; must retain grid pointer in the LatticeView class which sucks
203// Use a different method, and make it void *.
204// Perhaps a conformable method.
// GridFromExpression walks an expression tree. Each lattice leaf calls
// lat.Conformable(grid) with the grid pointer passed by reference (presumably
// recording the grid on first use and checking later leaves against it --
// confirm in Lattice::Conformable). Non-lattice leaves do nothing, and the
// interior nodes recurse into every operand.
206template <class T1,typename std::enable_if<is_lattice<T1>::value, T1>::type * = nullptr>
207accelerator_inline void GridFromExpression(GridBase *&grid, const T1 &lat) // Lattice leaf
208{
209 lat.Conformable(grid);
210}
211
212template <class T1,typename std::enable_if<!is_lattice<T1>::value, T1>::type * = nullptr>
// NOTE(review): source line 213 is missing from this listing (likely a
// function qualifier); confirm against the upstream header.
214void GridFromExpression(GridBase *&grid,const T1 &notlat) // non-lattice leaf
215{}
216
217template <typename Op, typename T1>
// NOTE(review): source lines 218-219 (the declarator, presumably taking a
// LatticeUnaryExpression<Op, T1>) are missing from this listing.
// Unary node: recurse into the single operand.
220{
221 GridFromExpression(grid, expr.arg1); // recurse
222}
223
224template <typename Op, typename T1, typename T2>
// NOTE(review): source lines 225-226 (binary-expression declarator) are missing.
// Binary node: recurse into both operands.
227{
228 GridFromExpression(grid, expr.arg1); // recurse
229 GridFromExpression(grid, expr.arg2);
230}
231template <typename Op, typename T1, typename T2, typename T3>
// NOTE(review): source lines 232-233 (trinary-expression declarator) are missing.
// Trinary node: recurse into all three operands.
234{
235 GridFromExpression(grid, expr.arg1); // recurse
236 GridFromExpression(grid, expr.arg2); // recurse
237 GridFromExpression(grid, expr.arg3); // recurse
238}
239
241// Obtain the CB from an expression, ensuring conformable. This must follow a
242// tree recursion
244template <class T1,typename std::enable_if<is_lattice<T1>::value, T1>::type * = nullptr>
245inline void CBFromExpression(int &cb, const T1 &lat) // Lattice leaf
246{
247 if ((cb == Odd) || (cb == Even)) {
248 assert(cb == lat.Checkerboard());
249 }
250 cb = lat.Checkerboard();
251}
252template <class T1,typename std::enable_if<!is_lattice<T1>::value, T1>::type * = nullptr>
253inline void CBFromExpression(int &cb, const T1 &notlat) {} // non-lattice leaf
// Interior nodes of CBFromExpression: thread cb through every operand in
// order, so each lattice leaf is checked against the checkerboard seen first.
254template <typename Op, typename T1> inline
// NOTE(review): source line 255 (the declarator, presumably taking a
// LatticeUnaryExpression<Op, T1>) is missing from this listing.
256{
257 CBFromExpression(cb, expr.arg1); // recurse AST
258}
259template <typename Op, typename T1, typename T2> inline
// NOTE(review): source line 260 (binary-expression declarator) is missing.
261{
262 CBFromExpression(cb, expr.arg1); // recurse AST
263 CBFromExpression(cb, expr.arg2); // recurse AST
264}
265template <typename Op, typename T1, typename T2, typename T3>
// NOTE(review): source line 266 (trinary-expression declarator) is missing.
267{
268 CBFromExpression(cb, expr.arg1); // recurse AST
269 CBFromExpression(cb, expr.arg2); // recurse AST
270 CBFromExpression(cb, expr.arg3); // recurse AST
271}
272
273
275// ViewOpen
277template <class T1,typename std::enable_if<is_lattice<T1>::value, T1>::type * = nullptr>
278inline void ExpressionViewOpen(T1 &lat) // Lattice leaf
279{
280 lat.ViewOpen(AcceleratorRead);
281}
282template <class T1,typename std::enable_if<!is_lattice<T1>::value, T1>::type * = nullptr>
283 inline void ExpressionViewOpen(T1 &notlat) {}
284
// Interior nodes of ExpressionViewOpen: recurse into every operand so that
// each lattice leaf has its view opened.
285template <typename Op, typename T1> inline
// NOTE(review): source line 286 (the declarator, presumably taking a
// LatticeUnaryExpression<Op, T1>) is missing from this listing.
287{
288 ExpressionViewOpen(expr.arg1); // recurse AST
289}
290
291template <typename Op, typename T1, typename T2> inline
// NOTE(review): source line 292 (binary-expression declarator) is missing.
293{
294 ExpressionViewOpen(expr.arg1); // recurse AST
295 ExpressionViewOpen(expr.arg2); // recurse AST
296}
297template <typename Op, typename T1, typename T2, typename T3>
// NOTE(review): source line 298 (trinary-expression declarator) is missing.
299{
300 ExpressionViewOpen(expr.arg1); // recurse AST
301 ExpressionViewOpen(expr.arg2); // recurse AST
302 ExpressionViewOpen(expr.arg3); // recurse AST
303}
304
306// ViewClose
308template <class T1,typename std::enable_if<is_lattice<T1>::value, T1>::type * = nullptr>
309inline void ExpressionViewClose( T1 &lat) // Lattice leaf
310{
311 lat.ViewClose();
312}
313template <class T1,typename std::enable_if<!is_lattice<T1>::value, T1>::type * = nullptr>
314inline void ExpressionViewClose(T1 &notlat) {}
315
// Interior nodes of ExpressionViewClose: recurse into every operand so that
// each lattice leaf has its view closed.
316template <typename Op, typename T1> inline
// NOTE(review): source line 317 (the declarator, presumably taking a
// LatticeUnaryExpression<Op, T1>) is missing from this listing.
318{
319 ExpressionViewClose(expr.arg1); // recurse AST
320}
321template <typename Op, typename T1, typename T2> inline
// NOTE(review): source line 322 (binary-expression declarator) is missing.
323{
324 ExpressionViewClose(expr.arg1); // recurse AST
325 ExpressionViewClose(expr.arg2); // recurse AST
326}
327template <typename Op, typename T1, typename T2, typename T3>
// NOTE(review): source line 328 (trinary-expression declarator) is missing.
329{
330 ExpressionViewClose(expr.arg1); // recurse AST
331 ExpressionViewClose(expr.arg2); // recurse AST
332 ExpressionViewClose(expr.arg3); // recurse AST
333}
334
336// Unary operators and funcs
338#define GridUnopClass(name, ret) \
339 struct name { \
340 template<class _arg> static auto accelerator_inline func(const _arg a) -> decltype(ret) { return ret; } \
341 };
342
343GridUnopClass(UnarySub, -a);
344GridUnopClass(UnaryNot, Not(a));
345GridUnopClass(UnaryTrace, trace(a));
346GridUnopClass(UnaryTranspose, transpose(a));
347GridUnopClass(UnaryTa, Ta(a));
348GridUnopClass(UnarySpTa, SpTa(a));
349GridUnopClass(UnaryProjectOnGroup, ProjectOnGroup(a));
350GridUnopClass(UnaryProjectOnSpGroup, ProjectOnSpGroup(a));
351GridUnopClass(UnaryTimesI, timesI(a));
352GridUnopClass(UnaryTimesMinusI, timesMinusI(a));
353GridUnopClass(UnaryAbs, abs(a));
354GridUnopClass(UnarySqrt, sqrt(a));
355GridUnopClass(UnarySin, sin(a));
356GridUnopClass(UnaryCos, cos(a));
357GridUnopClass(UnaryAsin, asin(a));
358GridUnopClass(UnaryAcos, acos(a));
359GridUnopClass(UnaryLog, log(a));
360GridUnopClass(UnaryExp, exp(a));
361
363// Binary operators
365#define GridBinOpClass(name, combination) \
366 struct name { \
367 template <class _left, class _right> \
368 static auto accelerator_inline \
369 func(const _left &lhs, const _right &rhs) \
370 -> decltype(combination) const \
371 { \
372 return combination; \
373 } \
374 };
375
376GridBinOpClass(BinaryAdd, lhs + rhs);
377GridBinOpClass(BinarySub, lhs - rhs);
378GridBinOpClass(BinaryMul, lhs *rhs);
379GridBinOpClass(BinaryDiv, lhs /rhs);
380GridBinOpClass(BinaryAnd, lhs &rhs);
381GridBinOpClass(BinaryOr, lhs | rhs);
382GridBinOpClass(BinaryAndAnd, lhs &&rhs);
383GridBinOpClass(BinaryOrOr, lhs || rhs);
384
////////////////////////////////////////////
// Trinary conditional op
////////////////////////////////////////////
// GridTrinOpClass(name, combination) defines a functor struct `name` whose
// static template func(pred, lhs, rhs) returns `combination`. The combination
// is written in terms of pred/lhs/rhs and may name the template parameters
// _predicate, _left and _right.
// The return type is deliberately not const-qualified. A trailing `const`
// after `decltype(...)` would qualify the by-value return type, not the
// (static) function, which only suppresses moves of the result.
#define GridTrinOpClass(name, combination)                            \
  struct name {                                                       \
    template <class _predicate,class _left, class _right>             \
    static auto accelerator_inline                                    \
    func(const _predicate &pred, const _left &lhs, const _right &rhs) \
      -> decltype(combination)                                        \
    {                                                                 \
      return combination;                                             \
    }                                                                 \
  };
398
// TrinaryWhere: the functor behind where(pred, lhs, rhs). The reference-
// stripped predicate/left/right types are passed as explicit template
// arguments to the selection routine, which is then called with (pred, lhs, rhs).
// NOTE(review): source line 400 is missing from this listing, so the call
// is incomplete as shown. Given the three template arguments it presumably
// opens `predicatedWhere<`; confirm against the upstream header.
399GridTrinOpClass(TrinaryWhere,
401 typename std::remove_reference<_predicate>::type,
402 typename std::remove_reference<_left>::type,
403 typename std::remove_reference<_right>::type>(pred, lhs,rhs)));
404
406// Operator syntactical glue
// GRID_UNOP/GRID_BINOP/GRID_TRINOP expand to their argument unchanged; they
// only mark where a functor class name is expected. They are #undef'd again
// at the end of this header.
408#define GRID_UNOP(name) name
409#define GRID_BINOP(name) name
410#define GRID_TRINOP(name) name
411
// GRID_DEF_UNOP(op, name): defines function `op` for any lattice or lattice
// expression argument. It returns a LatticeUnaryExpression holding functor
// `name` and the argument; nothing is evaluated at this point.
412#define GRID_DEF_UNOP(op, name) \
413 template <typename T1, typename std::enable_if<is_lattice<T1>::value||is_lattice_expr<T1>::value,T1>::type * = nullptr> \
414 inline auto op(const T1 &arg) ->decltype(LatticeUnaryExpression<GRID_UNOP(name),T1>(GRID_UNOP(name)(), arg)) \
415 { \
416 return LatticeUnaryExpression<GRID_UNOP(name),T1>(GRID_UNOP(name)(), arg); \
417 }
418
// GRID_BINOP_LEFT: binary `op` overload selected when the left operand is a
// lattice or lattice expression (the right operand may be anything).
419#define GRID_BINOP_LEFT(op, name) \
420 template <typename T1, typename T2, \
421 typename std::enable_if<is_lattice<T1>::value||is_lattice_expr<T1>::value,T1>::type * = nullptr> \
422 inline auto op(const T1 &lhs, const T2 &rhs) \
423 ->decltype(LatticeBinaryExpression<GRID_BINOP(name),T1,T2>(GRID_BINOP(name)(),lhs,rhs)) \
424 { \
425 return LatticeBinaryExpression<GRID_BINOP(name),T1,T2>(GRID_BINOP(name)(),lhs,rhs);\
426 }
427
// GRID_BINOP_RIGHT: the complementary overload, selected when only the right
// operand is a lattice or lattice expression. Its enable_if conditions are
// disjoint from GRID_BINOP_LEFT's, so the two never compete.
428#define GRID_BINOP_RIGHT(op, name) \
429 template <typename T1, typename T2, \
430 typename std::enable_if<!is_lattice<T1>::value&&!is_lattice_expr<T1>::value,T1>::type * = nullptr, \
431 typename std::enable_if< is_lattice<T2>::value|| is_lattice_expr<T2>::value,T2>::type * = nullptr> \
432 inline auto op(const T1 &lhs, const T2 &rhs) \
433 ->decltype(LatticeBinaryExpression<GRID_BINOP(name),T1,T2>(GRID_BINOP(name)(),lhs, rhs)) \
434 { \
435 return LatticeBinaryExpression<GRID_BINOP(name),T1,T2>(GRID_BINOP(name)(),lhs, rhs); \
436 }
437
// GRID_DEF_BINOP: define both the left- and right-lattice overloads of `op`.
438#define GRID_DEF_BINOP(op, name) \
439 GRID_BINOP_LEFT(op, name); \
440 GRID_BINOP_RIGHT(op, name);
441
// GRID_DEF_TRINOP(op, name): three-argument `op` returning a
// LatticeTrinaryExpression holding functor `name` and (pred, lhs, rhs).
// NOTE(review): unlike the unary/binary forms, T1..T3 are unconstrained, so
// the generated template takes part in overload resolution for any
// three-argument call of `op`. Confirm this is intended.
442#define GRID_DEF_TRINOP(op, name) \
443 template <typename T1, typename T2, typename T3> \
444 inline auto op(const T1 &pred, const T2 &lhs, const T3 &rhs) \
445 ->decltype(LatticeTrinaryExpression<GRID_TRINOP(name),T1,T2,T3>(GRID_TRINOP(name)(),pred, lhs, rhs)) \
446 { \
447 return LatticeTrinaryExpression<GRID_TRINOP(name),T1,T2,T3>(GRID_TRINOP(name)(),pred, lhs, rhs); \
448 }
449
451// Operator definitions
453GRID_DEF_UNOP(operator-, UnarySub);
455GRID_DEF_UNOP(operator!, UnaryNot);
456//GRID_DEF_UNOP(adj, UnaryAdj);
457//GRID_DEF_UNOP(conjugate, UnaryConj);
458GRID_DEF_UNOP(trace, UnaryTrace);
459GRID_DEF_UNOP(transpose, UnaryTranspose);
461GRID_DEF_UNOP(SpTa, UnarySpTa);
462GRID_DEF_UNOP(ProjectOnGroup, UnaryProjectOnGroup);
463GRID_DEF_UNOP(ProjectOnSpGroup, UnaryProjectOnSpGroup);
464GRID_DEF_UNOP(timesI, UnaryTimesI);
465GRID_DEF_UNOP(timesMinusI, UnaryTimesMinusI);
466GRID_DEF_UNOP(abs, UnaryAbs); // abs overloaded in cmath C++98; DON'T do the
467 // abs-fabs-dabs-labs thing
468GRID_DEF_UNOP(sqrt, UnarySqrt);
471GRID_DEF_UNOP(asin, UnaryAsin);
472GRID_DEF_UNOP(acos, UnaryAcos);
475
476GRID_DEF_BINOP(operator+, BinaryAdd);
477GRID_DEF_BINOP(operator-, BinarySub);
478GRID_DEF_BINOP(operator*, BinaryMul);
479GRID_DEF_BINOP(operator/, BinaryDiv);
480
481GRID_DEF_BINOP(operator&, BinaryAnd);
482GRID_DEF_BINOP(operator|, BinaryOr);
483GRID_DEF_BINOP(operator&&, BinaryAndAnd);
484GRID_DEF_BINOP(operator||, BinaryOrOr);
485
486GRID_DEF_TRINOP(where, TrinaryWhere);
487
489// Closure convenience to force expression to evaluate
// closure(expr) builds a Lattice from an expression (via Lattice's
// constructor taking the expression). The element type is the const-stripped
// result type of applying the expression's functor to vecEval of its
// operands at site 0; vecEval is used only for this type inference.
491template <class Op, class T1>
// NOTE(review): source line 492 is missing from this listing. The
// cross-reference entry for line 492 gives the declarator as
//   auto closure(const LatticeUnaryExpression<Op, T1> &expr)
493 -> Lattice<typename std::remove_const<decltype(expr.op.func(vecEval(0, expr.arg1)))>::type >
494{
495 Lattice<typename std::remove_const<decltype(expr.op.func(vecEval(0, expr.arg1)))>::type > ret(expr);
496 return ret;
497}
498template <class Op, class T1, class T2>
// NOTE(review): source line 499 (declarator, presumably taking a
// LatticeBinaryExpression<Op, T1, T2>) is missing from this listing.
500 -> Lattice<typename std::remove_const<decltype(expr.op.func(vecEval(0, expr.arg1),vecEval(0, expr.arg2)))>::type >
501{
502 Lattice<typename std::remove_const<decltype(expr.op.func(vecEval(0, expr.arg1),vecEval(0, expr.arg2)))>::type > ret(expr);
503 return ret;
504}
505template <class Op, class T1, class T2, class T3>
// NOTE(review): source line 506 (declarator, presumably taking a
// LatticeTrinaryExpression<Op, T1, T2, T3>) is missing from this listing.
507 -> Lattice<typename std::remove_const<decltype(expr.op.func(vecEval(0, expr.arg1),
508 vecEval(0, expr.arg2),
509 vecEval(0, expr.arg3)))>::type >
510{
511 Lattice<typename std::remove_const<decltype(expr.op.func(vecEval(0, expr.arg1),
512 vecEval(0, expr.arg2),
513 vecEval(0, expr.arg3)))>::type > ret(expr);
514 return ret;
515}
// EXPRESSION_CLOSURE(function): for any lattice-expression argument, define
// function(expr) as function(closure(expr)). The expression is evaluated
// into a Lattice first and the call is forwarded to the existing overload
// of `function`.
// NOTE(review): the parameter is a non-const lvalue reference. The generated
// overload therefore cannot bind a temporary expression such as
// function(a+b). Consider `const Expression &`, after checking for
// ambiguities with other overloads of `function`.
516#define EXPRESSION_CLOSURE(function) \
517 template<class Expression,typename std::enable_if<is_lattice_expr<Expression>::value,void>::type * = nullptr> \
518 auto function(Expression &expr) -> decltype(function(closure(expr))) \
519 { \
520 return function(closure(expr)); \
521 }
522
523
524#undef GRID_UNOP
525#undef GRID_BINOP
526#undef GRID_TRINOP
527
528#undef GRID_DEF_UNOP
529#undef GRID_DEF_BINOP
530#undef GRID_DEF_TRINOP
531
533
534#endif
accelerator_inline int acceleratorSIMTlane(int Nsimd)
#define accelerator_inline
static const int Even
static const int Odd
accelerator_inline void timesMinusI(Grid_simd2< S, V > &ret, const Grid_simd2< S, V > &in)
accelerator_inline void timesI(Grid_simd2< S, V > &ret, const Grid_simd2< S, V > &in)
accelerator_inline Grid_simd2< S, V > trace(const Grid_simd2< S, V > &arg)
accelerator_inline Grid_simd< S, V > cos(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > abs(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > asin(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > sin(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > Not(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > acos(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > exp(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > sqrt(const Grid_simd< S, V > &r)
accelerator_inline Grid_simd< S, V > log(const Grid_simd< S, V > &r)
accelerator_inline void GridFromExpression(GridBase *&grid, const T1 &lat)
Definition Lattice_ET.h:207
#define GRID_DEF_BINOP(op, name)
Definition Lattice_ET.h:438
auto closure(const LatticeUnaryExpression< Op, T1 > &expr) -> Lattice< typename std::remove_const< decltype(expr.op.func(vecEval(0, expr.arg1)))>::type >
Definition Lattice_ET.h:492
accelerator_inline sobj eval(const uint64_t ss, const sobj &arg)
Definition Lattice_ET.h:103
void ExpressionViewOpen(T1 &lat)
Definition Lattice_ET.h:278
#define GRID_DEF_UNOP(op, name)
Definition Lattice_ET.h:412
accelerator_inline sobj vecEval(const uint64_t ss, const sobj &arg)
Definition Lattice_ET.h:118
#define GridTrinOpClass(name, combination)
Definition Lattice_ET.h:388
#define GridUnopClass(name, ret)
Definition Lattice_ET.h:338
#define GRID_DEF_TRINOP(op, name)
Definition Lattice_ET.h:442
void ExpressionViewClose(T1 &lat)
Definition Lattice_ET.h:309
accelerator_inline vobj predicatedWhere(const iobj &predicate, const vobj &iftrue, const robj &iffalse)
Definition Lattice_ET.h:59
void CBFromExpression(int &cb, const T1 &lat)
Definition Lattice_ET.h:245
#define GridBinOpClass(name, combination)
Definition Lattice_ET.h:365
std::is_base_of< LatticeExpressionBase, T > is_lattice_expr
@ AcceleratorRead
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
#define T1
accelerator_inline iScalar< vtype > ProjectOnSpGroup(const iScalar< vtype > &r)
Definition Tensor_Ta.h:193
accelerator_inline iScalar< vtype > SpTa(const iScalar< vtype > &r)
Definition Tensor_Ta.h:69
accelerator_inline iScalar< vtype > ProjectOnGroup(const iScalar< vtype > &r)
Definition Tensor_Ta.h:124
accelerator_inline iScalar< vtype > Ta(const iScalar< vtype > &r)
Definition Tensor_Ta.h:45
accelerator_inline std::enable_if<!isGridTensor< T >::value, T >::type TensorRemove(T arg)
AcceleratorVector< __T,GRID_MAX_SIMD > ExtractBuffer
accelerator void extract(const vobj &vec, ExtractBuffer< sobj > &extracted)
accelerator void merge(vobj &vec, ExtractBuffer< sobj > &extracted)
accelerator_inline ComplexD transpose(ComplexD &rhs)
vobj vector_object
Lattice< T >::vector_object type
Definition Lattice_ET.h:93