Grid 0.7.0
HMCResourceManager.h
Go to the documentation of this file.
1/*************************************************************************************
2
3Grid physics library, www.github.com/paboyle/Grid
4
5Source file: ./lib/qcd/hmc/HMCResourceManager.h
6
7Copyright (C) 2015
8Copyright (C) 2016
9
10Author: Guido Cossu <guido.cossu@ed.ac.uk>
11
12This program is free software; you can redistribute it and/or modify
13it under the terms of the GNU General Public License as published by
14the Free Software Foundation; either version 2 of the License, or
15(at your option) any later version.
16
17This program is distributed in the hope that it will be useful,
18but WITHOUT ANY WARRANTY; without even the implied warranty of
19MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
20GNU General Public License for more details.
21
22You should have received a copy of the GNU General Public License along
23with this program; if not, write to the Free Software Foundation, Inc.,
2451 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
25
26 See the full license in the file "LICENSE" in the top level distribution
27 directory
28*************************************************************************************/
29 /* END LEGAL */
30#ifndef HMC_RESOURCE_MANAGER_H
31#define HMC_RESOURCE_MANAGER_H
32
33#include <unordered_map>
34
36
37// HMC Resource manager
38template <class ImplementationPolicy>
43 typedef typename ImplementationPolicy::Field MomentaField;
44 typedef typename ImplementationPolicy::Field Field;
45
46 // Named storage for grid pairs (std + red-black)
47 std::unordered_map<std::string, GridModule> Grids;
49
50 // SmearingModule<ImplementationPolicy> Smearing;
51 std::unique_ptr<CheckpointerBaseModule> CP;
52
53 // Momentum filter
54 std::unique_ptr<MomentumFilterBase<typename ImplementationPolicy::Field> > Filter;
55
56 // A vector of HmcObservable modules
57 std::vector<std::unique_ptr<ObservableBaseModule> > ObservablesList;
58
59
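60 // Action modules, keyed by the multiplier read for the level they belong to (see fill_ActionsLevel); NOTE(review): GetActionSet indexes Aset with key-1, i.e. as a 1-based level number — confirm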
61 std::multimap<int, std::unique_ptr<ActionBaseModule> > ActionsList;
62 std::vector<int> multipliers;
63
67
68 // NOTE: operator << is not overloaded for std::vector<string>
69 // so this function is necessary
// Print each element of vs to std::cout followed by a single space, then end
// the line with std::endl (which also flushes). An empty vector prints only
// the newline.
void output_vector_string(const std::vector<std::string> &vs){
  const std::size_t count = vs.size();
  for (std::size_t idx = 0; idx != count; ++idx)
    std::cout << vs[idx] << " ";
  std::cout << std::endl;
}
75
76
77public:
79
80 template <class ReaderClass, class vector_type = vComplex >
81 void initialize(ReaderClass &Read){
82 // assumes we are starting from the main node
83
84 // Geometry
85 GridModuleParameters GridPar(Read);
86 GridFourDimModule<vector_type> GridMod( GridPar) ;
87 AddGrid("gauge", GridMod);
88
89 // Checkpointer
91 Read.push("Checkpointer");
92 std::string cp_type;
93 read(Read,"name", cp_type);
94 std::cout << "Registered types " << std::endl;
95 output_vector_string(CPfactory.getBuilderList());
96
97
98 CP = CPfactory.create(cp_type, Read);
99 CP->print_parameters();
100 Read.pop();
101 have_CheckPointer = true;
102
103 RNGModuleParameters RNGpar(Read);
104 SetRNGSeeds(RNGpar);
105
106
107 // Observables
109 Read.push(observable_string);// here must check if existing...
110 do {
111 std::string obs_type;
112 read(Read,"name", obs_type);
113 std::cout << "Registered types " << std::endl;
114 output_vector_string(ObsFactory.getBuilderList() );
115
116 ObservablesList.emplace_back(ObsFactory.create(obs_type, Read));
117 ObservablesList[ObservablesList.size() - 1]->print_parameters();
118 } while (Read.nextElement(observable_string));
119 Read.pop();
120
121 // Loop on levels
122 if(!Read.push("Actions")){
123 std::cout << "Actions not found" << std::endl;
124 exit(1);
125 }
126
127 if(!Read.push("Level")){// push must check if the node exist
128 std::cout << "Level not found" << std::endl;
129 exit(1);
130 }
131 do
132 {
133 fill_ActionsLevel(Read);
134 }
135 while(Read.push("Level"));
136
137 Read.pop();
138 }
139
140
141
142 template <class RepresentationPolicy>
144 Aset.resize(multipliers.size());
145
146 for(auto it = ActionsList.begin(); it != ActionsList.end(); it++){
147 (*it).second->acquireResource(Grids["gauge"]);
148 Aset[(*it).first-1].push_back((*it).second->getPtr());
149 }
150 }
151
152
153
155 // Grids
157
158 void AddGrid(const std::string s, GridModule& M) {
159 // Check for name clashes
160 auto search = Grids.find(s);
161 if (search != Grids.end()) {
162 std::cout << GridLogError << "Grid with name \"" << search->first
163 << "\" already present. Terminating\n";
164 exit(1);
165 }
166 Grids[s] = std::move(M);
167 std::cout << GridLogMessage << "::::::::::::::::::::::::::::::::::::::::" <<std::endl;
168 std::cout << GridLogMessage << "HMCResourceManager:" << std::endl;
169 std::cout << GridLogMessage << "Created grid set with name '" << s << "' and decomposition for the full cartesian " << std::endl;
170 Grids[s].show_full_decomposition();
171 std::cout << GridLogMessage << "::::::::::::::::::::::::::::::::::::::::" <<std::endl;
172 }
173
174 // Add a named grid set, 4d shortcut
175 void AddFourDimGrid(const std::string s) {
177 AddGrid(s, Mod);
178 }
179
180 // Add a named grid set, 4d shortcut + tweak simd lanes
181 void AddFourDimGrid(const std::string s, const std::vector<int> simd_decomposition) {
182 GridFourDimModule<vComplex> Mod(simd_decomposition);
183 AddGrid(s, Mod);
184 }
185
187 assert(have_Filter==false);
188 Filter = std::unique_ptr<MomentumFilterBase<typename ImplementationPolicy::Field> >(MomFilter);
189 have_Filter = true;
190 }
196
197 GridCartesian* GetCartesian(std::string s = "") {
198 if (s.empty()) s = Grids.begin()->first;
199 std::cout << GridLogDebug << "Getting cartesian grid from: " << s
200 << std::endl;
201 return Grids[s].get_full();
202 }
203
205 if (s.empty()) s = Grids.begin()->first;
206 std::cout << GridLogDebug << "Getting rb-cartesian grid from: " << s
207 << std::endl;
208 return Grids[s].get_rb();
209 }
210
212 // Random number generators
214
215 //Return true if the RNG objects have been instantiated
216 bool haveRNGs() const{ return have_RNG; }
217
218 void AddRNGs(std::string s = "") {
219 // Couple the RNGs to the GridModule tagged by s
220 // the default is the first grid registered
221 assert(Grids.size() > 0 && !have_RNG);
222 if (s.empty()) s = Grids.begin()->first;
223 std::cout << GridLogDebug << "Adding RNG to grid: " << s << std::endl;
224 RNGs.set_pRNG(new GridParallelRNG(GetCartesian(s)));
225 have_RNG = true;
226 }
227
228 void SetRNGSeeds(RNGModuleParameters& Params) { RNGs.set_RNGSeeds(Params); }
229
230 GridSerialRNG& GetSerialRNG() { return RNGs.get_sRNG(); }
231
233 assert(have_RNG);
234 return RNGs.get_pRNG();
235 }
236
238 assert(have_RNG);
239 RNGs.seed();
240 }
241
243 // Checkpointers
245
248 return CP->getPtr();
249 else {
250 std::cout << GridLogError << "Error: no checkpointer defined"
251 << std::endl;
252 exit(1);
253 }
254 }
255
256 /* Load a checkpointer with no attached metadata.
257 HMCResourceManager must have a single checkpointer loaded.
258 Starting a HMC run without a checkpointer loaded,
259 or attempting to load multiple checkpointers,
260 will cause the application to exit with an error.
261
262 Instantiations for built-in checkpointers are below;
263 see LoadBinaryCheckpointer, LoadNerscCheckpointer, LoadILDGCheckpointer,
264 and LoadScidacCheckpointer.
265 To load your own checkpointer,
266 use a line similar to
267
268 TheHMC.Resources.template LoadCheckpointer<MyCheckPointer>(params);
269 */
270 template<template<class CPImplementationPolicy> class CheckpointModule>
272 typedef CheckpointModule<ImplementationPolicy> CPM;
274 {
275 std::cout << GridLogDebug << "Loading Checkpointer " << CPM::Name << std::endl;
276 CP = std::unique_ptr<CheckpointerBaseModule>(new CPM(Params_));
277 have_CheckPointer = true;
278 } else {
279 std::cout << GridLogError << "Checkpointer already loaded " << std::endl;
280 exit(1);
281 }
282 }
283
284 /* Load a checkpointer with attached metadata;
285 see the definition of LoadCheckpointer(const CheckpointerParameters& Params_)
286 for further details. */
287 template<template<class CPImplementationPolicy, class Metadata> class CheckpointModule, class Metadata>
288 void LoadCheckpointer(const CheckpointerParameters& Params_, const Metadata& M_) {
289 typedef CheckpointModule<ImplementationPolicy, Metadata> CPM;
290 if (!have_CheckPointer) {
291 std::cout << GridLogDebug << "Loading Metadata Checkpointer " << CPM::Name << std::endl;
292 CP = std::unique_ptr<CheckpointerBaseModule>( new CPM(Params_, M_));
293 have_CheckPointer = true;
294 } else {
295 std::cout << GridLogError << "Checkpointer already loaded " << std::endl;
296 exit(1);
297 }
298 }
299
300 /* Checkpoint loaders for built-in checkpointers. */
303#ifdef HAVE_LIME
304 void LoadILDGCheckpointer (const CheckpointerParameters& Params_) { LoadCheckpointer<ILDGCPModule>(Params_); }
305
306 template<class Metadata>
307 void LoadScidacCheckpointer(const CheckpointerParameters& Params_, const Metadata& M_)
308 {
310 }
311
312#endif
313
315 // Observables
317
318 template<class T, class... Types>
319 void AddObservable(Types&&... Args){
320 ObservablesList.push_back(std::unique_ptr<T>(new T(std::forward<Types>(Args)...)));
321 ObservablesList.back()->print_parameters();
322 }
323
324 std::vector<HmcObservable<typename ImplementationPolicy::Field>* > GetObservables(){
325 std::vector<HmcObservable<typename ImplementationPolicy::Field>* > out;
326 for (auto &i : ObservablesList){
327 out.push_back(i->getPtr());
328 }
329
330 // Add the checkpointer to the observables
331 out.push_back(GetCheckPointer());
332 return out;
333 }
334
335
336
337private:
338 // Private helpers (fill_ActionsLevel is called from initialize)
339 template <class ReaderClass >
340 void fill_ActionsLevel(ReaderClass &Read){
341 // Actions set
342 int m;
343 Read.readDefault("multiplier",m);
344 multipliers.push_back(m);
345 std::cout << "Level : " << multipliers.size() << " with multiplier : " << m << std::endl;
346 // here gauge
347 Read.push("Action");
348 do{
350 std::string action_type;
351 Read.readDefault("name", action_type);
352 output_vector_string(ActionFactory.getBuilderList() );
353 ActionsList.emplace(m, ActionFactory.create(action_type, Read));
354 } while (Read.nextElement("Action"));
355 ActionsList.find(m)->second->print_parameters();
356 Read.pop();
357
358 }
359
360
361
362};
363
365
366#endif // HMC_RESOURCE_MANAGER_H
std::vector< ActionLevel< GaugeField, R > > ActionSet
Definition ActionSet.h:108
GridLogger GridLogError(1, "Error", GridLogColours, "RED")
GridLogger GridLogDebug(1, "Debug", GridLogColours, "PURPLE")
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
char observable_string[]
GridRedBlackCartesian * GetRBCartesian(std::string s="")
ImplementationPolicy::Field MomentaField
void fill_ActionsLevel(ReaderClass &Read)
void GetActionSet(ActionSet< typename ImplementationPolicy::Field, RepresentationPolicy > &Aset)
ImplementationPolicy::Field Field
HMCModuleBase< BaseHmcCheckpointer< ImplementationPolicy > > CheckpointerBaseModule
std::multimap< int, std::unique_ptr< ActionBaseModule > > ActionsList
MomentumFilterBase< typename ImplementationPolicy::Field > * GetMomentumFilter(void)
void AddRNGs(std::string s="")
void SetRNGSeeds(RNGModuleParameters &Params)
std::vector< std::unique_ptr< ObservableBaseModule > > ObservablesList
std::unique_ptr< CheckpointerBaseModule > CP
GridParallelRNG & GetParallelRNG()
void AddGrid(const std::string s, GridModule &M)
std::unique_ptr< MomentumFilterBase< typename ImplementationPolicy::Field > > Filter
GridSerialRNG & GetSerialRNG()
void AddFourDimGrid(const std::string s)
void output_vector_string(const std::vector< std::string > &vs)
void LoadBinaryCheckpointer(const CheckpointerParameters &Params_)
void LoadCheckpointer(const CheckpointerParameters &Params_, const Metadata &M_)
std::vector< HmcObservable< typename ImplementationPolicy::Field > * > GetObservables()
void SetMomentumFilter(MomentumFilterBase< typename ImplementationPolicy::Field > *MomFilter)
std::vector< int > multipliers
ActionModuleBase< Action< typename ImplementationPolicy::Field >, GridModule > ActionBaseModule
std::unordered_map< std::string, GridModule > Grids
void AddObservable(Types &&... Args)
GridCartesian * GetCartesian(std::string s="")
void AddFourDimGrid(const std::string s, const std::vector< int > simd_decomposition)
void initialize(ReaderClass &Read)
void LoadCheckpointer(const CheckpointerParameters &Params_)
void LoadNerscCheckpointer(const CheckpointerParameters &Params_)
BaseHmcCheckpointer< ImplementationPolicy > * GetCheckPointer()
HMCModuleBase< HmcObservable< typename ImplementationPolicy::Field > > ObservableBaseModule
static HMC_ActionModuleFactory & getInstance(void)
static HMC_CPModuleFactory & getInstance(void)
static HMC_ObservablesModuleFactory & getInstance(void)