Grid 0.7.0
Tensor_traits.h
Go to the documentation of this file.
1 /*************************************************************************************
2 Grid physics library, www.github.com/paboyle/Grid
3 Source file: ./lib/tensors/Tensor_traits.h
4 Copyright (C) 2015
5Author: Azusa Yamaguchi <ayamaguc@staffmail.ed.ac.uk>
6Author: Peter Boyle <paboyle@ph.ed.ac.uk>
7Author: Christopher Kelly <ckelly@phys.columbia.edu>
8Author: Michael Marshall <michael.marshall@ed.ac.au>
9Author: Christoph Lehner <christoph@lhnr.de>
10 This program is free software; you can redistribute it and/or modify
11 it under the terms of the GNU General Public License as published by
12 the Free Software Foundation; either version 2 of the License, or
13 (at your option) any later version.
14 This program is distributed in the hope that it will be useful,
15 but WITHOUT ANY WARRANTY; without even the implied warranty of
16 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
17 GNU General Public License for more details.
18 You should have received a copy of the GNU General Public License along
19 with this program; if not, write to the Free Software Foundation, Inc.,
20 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 See the full license in the file "LICENSE" in the top level distribution directory
22 *************************************************************************************/
23 /* END LEGAL */
24#pragma once
25
26#include <type_traits>
27
29
// Forward declarations: the tensor class templates are defined in other
// headers; the traits below only need to be able to name them.
template <typename vtype> class iScalar;
template <typename vtype, int N> class iVector;
template <typename vtype, int N> class iMatrix;

// isGridTensor<T> -- std::true_type when T is one of the three Grid tensor
// templates (iScalar, iVector, iMatrix), std::false_type for any other type.
// The extra constant `notvalue` is always the negation of `value`.
template <typename T>
struct isGridTensor : std::false_type {
  static constexpr bool notvalue = true;
};

template <typename vtype>
struct isGridTensor<iScalar<vtype>> : std::true_type {
  static constexpr bool notvalue = false;
};

template <typename vtype, int N>
struct isGridTensor<iVector<vtype, N>> : std::true_type {
  static constexpr bool notvalue = false;
};

template <typename vtype, int N>
struct isGridTensor<iMatrix<vtype, N>> : std::true_type {
  static constexpr bool notvalue = false;
};
40
41 template <typename T> using IfGridTensor = Invoke<std::enable_if<isGridTensor<T>::value, int> >;
43
44 // Traits to identify scalars
45 template<typename T> struct isGridScalar : public std::false_type { static constexpr bool notvalue = true; };
46 template<class T> struct isGridScalar<iScalar<T>> : public std::true_type { static constexpr bool notvalue = false; };
47
48
49 // Traits to identify fundamental data types
50 template<typename T> struct isGridFundamental : public std::false_type { static constexpr bool notvalue = true; };
51 template<> struct isGridFundamental<vComplexF> : public std::true_type { static constexpr bool notvalue = false; };
52 template<> struct isGridFundamental<vComplexD> : public std::true_type { static constexpr bool notvalue = false; };
53 template<> struct isGridFundamental<vRealF> : public std::true_type { static constexpr bool notvalue = false; };
54 template<> struct isGridFundamental<vRealD> : public std::true_type { static constexpr bool notvalue = false; };
55 template<> struct isGridFundamental<ComplexF> : public std::true_type { static constexpr bool notvalue = false; };
56 template<> struct isGridFundamental<ComplexD> : public std::true_type { static constexpr bool notvalue = false; };
57 template<> struct isGridFundamental<RealF> : public std::true_type { static constexpr bool notvalue = false; };
58 template<> struct isGridFundamental<RealD> : public std::true_type { static constexpr bool notvalue = false; };
59 template<> struct isGridFundamental<vComplexD2> : public std::true_type { static constexpr bool notvalue = false; };
60 template<> struct isGridFundamental<vRealD2> : public std::true_type { static constexpr bool notvalue = false; };
61
62
// Goal: recurse through nested tensors, e.g.
//   GridTypeMapper<iMatrix<vComplexD, N> >::scalar_type == ComplexD.
// Using a separate helper class like this lets us specialise the traits for,
// and so "dress", types we cannot modify, such as RealD (double) and
// ComplexD (std::complex<double>).
//
// It might be more elegant to give iScalar, iVector and iMatrix a queryable
// trait of their own and query that on vtype instead of using the type mapper.
//
// How best to do this is still open; the C++11 <type_traits> machinery
// (e.g. std::enable_if<isGridTensor<vtype>::value>) is the likely starting point.
//
77
78 // This saves repeating common properties for supported Grid Scalar types
79 // TensorLevel How many nested grid tensors
80 // Rank Rank of the grid tensor
81 // count Total number of elements, i.e. product of dimensions
82 // Dimension(dim) Size of dimension dim
84 static constexpr int TensorLevel = 0;
85 static constexpr int Rank = 0;
86 static constexpr std::size_t count = 1;
87 static constexpr int Dimension(int dim) { return 0; }
88 };
89
91// Recursion stops with these template specialisations
93
// Primary template: intentionally empty. Only the specialisations define
// scalar_type, vector_type, TensorLevel and friends, so asking for one of
// those members on an unsupported type fails to substitute / compile.
template <class T>
struct GridTypeMapper {
};
95
148
#if defined(GRID_CUDA) || defined(GRID_HIP)
// GridTypeMapper entries for std::complex<float> / std::complex<double>,
// compiled only for GPU (CUDA / HIP) builds.
// NOTE(review): presumably needed because ComplexF / ComplexD are not
// std::complex on these targets -- confirm against Simd.h.
// Both types are treated as plain scalars. The scalar, vector, reduced and
// object types are all the type itself. Realified is the underlying real
// type (RealF / RealD). The double-precision variants are
// std::complex<double>.
template <>
struct GridTypeMapper<std::complex<float>> : public GridTypeMapper_Base {
  using scalar_type      = std::complex<float>;
  using scalar_typeD     = std::complex<double>;
  using vector_type      = scalar_type;
  using vector_typeD     = scalar_typeD;
  using tensor_reduced   = scalar_type;
  using scalar_object    = scalar_type;
  using scalar_objectD   = scalar_typeD;
  using Complexified     = scalar_type;
  using Realified        = RealF;
  using DoublePrecision  = scalar_typeD;
  using DoublePrecision2 = scalar_typeD;
};

template <>
struct GridTypeMapper<std::complex<double>> : public GridTypeMapper_Base {
  using scalar_type      = std::complex<double>;
  using scalar_typeD     = std::complex<double>;
  using vector_type      = scalar_type;
  using vector_typeD     = scalar_typeD;
  using tensor_reduced   = scalar_type;
  using scalar_object    = scalar_type;
  using scalar_objectD   = scalar_typeD;
  using Complexified     = scalar_type;
  using Realified        = RealD;
  using DoublePrecision  = scalar_typeD;
  using DoublePrecision2 = scalar_typeD;
};
#endif
177
191
231 template<> struct GridTypeMapper<vRealH> : public GridTypeMapper_Base {
232 // Fixme this is incomplete until Grid supports fp16 or bfp16 arithmetic types
244 };
245 template<> struct GridTypeMapper<vComplexH> : public GridTypeMapper_Base {
246 // Fixme this is incomplete until Grid supports fp16 or bfp16 arithmetic types
258 };
311
// Boilerplate shared by the GridTypeMapper specialisations for tensor types.
// Expand it inside a struct template whose element-type parameter is named T.
// It re-exports the scalar/vector types of GridTypeMapper<T> and sets
// TensorLevel to one more than that of T.
// The last declaration has no trailing ';', so the use site supplies it:
//     GridTypeMapper_RepeatedTypes;
#define GridTypeMapper_RepeatedTypes                      \
  using BaseTraits   = GridTypeMapper<T>;                 \
  using scalar_type  = typename BaseTraits::scalar_type;  \
  using scalar_typeD = typename BaseTraits::scalar_typeD; \
  using vector_type  = typename BaseTraits::vector_type;  \
  using vector_typeD = typename BaseTraits::vector_typeD; \
  static constexpr int TensorLevel = 1 + BaseTraits::TensorLevel
319
334
349
364
// matchGridTensorIndex<T, Level> -- true when T's tensor nesting depth
// (T::TensorLevel) equals Level, false otherwise.
// It derives from std::integral_constant, so it has the same shape as the
// other traits in this file: `value`, `type` and conversion to bool.
// `notvalue` is the negation of `value`.
// Both constants are constexpr static members. Since C++17 these are
// implicitly inline, so they can be ODR-used (e.g. bound to a const
// reference) without an out-of-line definition.
template <typename T, int Level>
struct matchGridTensorIndex
  : public std::integral_constant<bool, (Level == T::TensorLevel)> {
  static constexpr bool notvalue = (Level != T::TensorLevel);
};
370 // What is the vtype
371 template<typename T> struct isComplex {
372 static const bool value = false;
373 };
374 template<> struct isComplex<ComplexF> {
375 static const bool value = true;
376 };
377 template<> struct isComplex<ComplexD> {
378 static const bool value = true;
379 };
380
381 //Get the SIMD vector type from a Grid tensor or Lattice<Tensor>
382 template<typename T>
384 typedef T type;
385 };
386
387 //Query whether a tensor or Lattice<Tensor> is SIMD vector or scalar
388 template<typename T, typename V=void> struct isSIMDvectorized : public std::false_type {};
389 template<typename U> struct isSIMDvectorized<U, typename std::enable_if< !std::is_same<
390 typename GridTypeMapper<typename getVectorType<U>::type>::scalar_type,
391 typename GridTypeMapper<typename getVectorType<U>::type>::vector_type>::value, void>::type>
392 : public std::true_type {};
393
394 //Get the precision of a Lattice, tensor or scalar type in units of sizeof(float)
395 template<typename T>
397 public:
398 //get the vector_obj (i.e. a grid Tensor) if its a Lattice<vobj>, do nothing otherwise (i.e. if fundamental or grid Tensor)
400 typedef typename GridTypeMapper<vector_obj>::scalar_type scalar_type; //get the associated scalar type. Works on fundamental and tensor types
401 typedef typename GridTypeMapper<scalar_type>::Realified real_scalar_type; //remove any std::complex wrapper, should get us to the fundamental type
402
403 enum { value = sizeof(real_scalar_type)/sizeof(float) };
404 };
406
407
408
Grid_simd2< complex< double >, vComplexD > vComplexD2
Grid_simd2< double, vRealD > vRealD2
Grid_simd< complex< float >, SIMD_Ftype > vComplexF
Grid_simd< uint16_t, SIMD_Htype > vRealH
Grid_simd< complex< uint16_t >, SIMD_Htype > vComplexH
Grid_simd< float, SIMD_Ftype > vRealF
Grid_simd< complex< double >, SIMD_Dtype > vComplexD
typename T::type Invoke
Grid_simd< Integer, SIMD_Itype > vInteger
Grid_simd< double, SIMD_Dtype > vRealD
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
uint32_t Integer
Definition Simd.h:58
std::complex< T > complex
Definition Simd.h:82
std::complex< RealF > ComplexF
Definition Simd.h:78
float RealF
Definition Simd.h:60
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
Invoke< std::enable_if<!isGridTensor< T >::value, int > > IfNotGridTensor
Invoke< std::enable_if< isGridTensor< T >::value, int > > IfGridTensor
static INTERNAL_PRECISION U
Definition Zolotarev.cc:230
getVectorType< fobj >::type vector_obj
GridTypeMapper< vector_obj >::scalar_type scalar_type
GridTypeMapper< scalar_type >::Realified real_scalar_type
iMatrix< typename BaseTraits::scalar_object, N > scalar_object
iMatrix< typename BaseTraits::DoublePrecision, N > DoublePrecision
iMatrix< typename BaseTraits::Realified, N > Realified
iMatrix< typename BaseTraits::scalar_objectD, N > scalar_objectD
static constexpr int Dimension(int dim)
static constexpr std::size_t count
iMatrix< typename BaseTraits::DoublePrecision2, N > DoublePrecision2
iMatrix< typename BaseTraits::Complexified, N > Complexified
iScalar< typename BaseTraits::tensor_reduced > tensor_reduced
iScalar< typename BaseTraits::DoublePrecision > DoublePrecision
iScalar< typename BaseTraits::scalar_object > scalar_object
iScalar< typename BaseTraits::DoublePrecision2 > DoublePrecision2
iScalar< typename BaseTraits::scalar_objectD > scalar_objectD
static constexpr std::size_t count
iScalar< typename BaseTraits::tensor_reduced > tensor_reduced
iScalar< typename BaseTraits::Complexified > Complexified
iScalar< typename BaseTraits::Realified > Realified
static constexpr int Dimension(int dim)
iVector< typename BaseTraits::Complexified, N > Complexified
iVector< typename BaseTraits::scalar_objectD, N > scalar_objectD
iVector< typename BaseTraits::DoublePrecision2, N > DoublePrecision2
iScalar< typename BaseTraits::tensor_reduced > tensor_reduced
iVector< typename BaseTraits::scalar_object, N > scalar_object
static constexpr int Dimension(int dim)
iVector< typename BaseTraits::DoublePrecision, N > DoublePrecision
static constexpr std::size_t count
iVector< typename BaseTraits::Realified, N > Realified
static constexpr std::size_t count
static constexpr int Rank
static constexpr int TensorLevel
static constexpr int Dimension(int dim)
static const bool value
static const bool value
static const bool value
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static constexpr bool notvalue
static const bool notvalue
static const bool value