30#ifndef HMC_RESOURCE_MANAGER_H
31#define HMC_RESOURCE_MANAGER_H
33#include <unordered_map>
38template <
class ImplementationPolicy>
44 typedef typename ImplementationPolicy::Field
Field;
47 std::unordered_map<std::string, GridModule>
Grids;
51 std::unique_ptr<CheckpointerBaseModule>
CP;
54 std::unique_ptr<MomentumFilterBase<typename ImplementationPolicy::Field> >
Filter;
61 std::multimap<int, std::unique_ptr<ActionBaseModule> >
ActionsList;
72 std::cout << i <<
" ";
73 std::cout << std::endl;
80 template <
class ReaderClass,
class vector_type = vComplex >
91 Read.push(
"Checkpointer");
93 read(Read,
"name", cp_type);
94 std::cout <<
"Registered types " << std::endl;
98 CP = CPfactory.create(cp_type, Read);
99 CP->print_parameters();
111 std::string obs_type;
112 read(Read,
"name", obs_type);
113 std::cout <<
"Registered types " << std::endl;
122 if(!Read.push(
"Actions")){
123 std::cout <<
"Actions not found" << std::endl;
127 if(!Read.push(
"Level")){
128 std::cout <<
"Level not found" << std::endl;
135 while(Read.push(
"Level"));
142 template <
class RepresentationPolicy>
147 (*it).second->acquireResource(
Grids[
"gauge"]);
148 Aset[(*it).first-1].push_back((*it).second->getPtr());
160 auto search =
Grids.find(s);
161 if (search !=
Grids.end()) {
162 std::cout <<
GridLogError <<
"Grid with name \"" << search->first
163 <<
"\" already present. Terminating\n";
166 Grids[s] = std::move(M);
167 std::cout <<
GridLogMessage <<
"::::::::::::::::::::::::::::::::::::::::" <<std::endl;
168 std::cout <<
GridLogMessage <<
"HMCResourceManager:" << std::endl;
169 std::cout <<
GridLogMessage <<
"Created grid set with name '" << s <<
"' and decomposition for the full cartesian " << std::endl;
170 Grids[s].show_full_decomposition();
171 std::cout <<
GridLogMessage <<
"::::::::::::::::::::::::::::::::::::::::" <<std::endl;
181 void AddFourDimGrid(
const std::string s,
const std::vector<int> simd_decomposition) {
188 Filter = std::unique_ptr<MomentumFilterBase<typename ImplementationPolicy::Field> >(MomFilter);
198 if (s.empty()) s =
Grids.begin()->first;
199 std::cout <<
GridLogDebug <<
"Getting cartesian grid from: " << s
201 return Grids[s].get_full();
205 if (s.empty()) s =
Grids.begin()->first;
206 std::cout <<
GridLogDebug <<
"Getting rb-cartesian grid from: " << s
208 return Grids[s].get_rb();
222 if (s.empty()) s =
Grids.begin()->first;
223 std::cout <<
GridLogDebug <<
"Adding RNG to grid: " << s << std::endl;
234 return RNGs.get_pRNG();
250 std::cout <<
GridLogError <<
"Error: no checkpointer defined"
270 template<
template<
class CPImplementationPolicy>
class CheckpointModule>
272 typedef CheckpointModule<ImplementationPolicy> CPM;
275 std::cout <<
GridLogDebug <<
"Loading Checkpointer " << CPM::Name << std::endl;
276 CP = std::unique_ptr<CheckpointerBaseModule>(
new CPM(Params_));
279 std::cout <<
GridLogError <<
"Checkpointer already loaded " << std::endl;
287 template<
template<
class CPImplementationPolicy,
class Metadata>
class CheckpointModule,
class Metadata>
289 typedef CheckpointModule<ImplementationPolicy, Metadata> CPM;
291 std::cout <<
GridLogDebug <<
"Loading Metadata Checkpointer " << CPM::Name << std::endl;
292 CP = std::unique_ptr<CheckpointerBaseModule>(
new CPM(Params_, M_));
295 std::cout <<
GridLogError <<
"Checkpointer already loaded " << std::endl;
306 template<
class Metadata>
307 void LoadScidacCheckpointer(
const CheckpointerParameters& Params_,
const Metadata& M_)
318 template<
class T,
class... Types>
320 ObservablesList.push_back(std::unique_ptr<T>(
new T(std::forward<Types>(Args)...)));
324 std::vector<HmcObservable<typename ImplementationPolicy::Field>* >
GetObservables(){
325 std::vector<HmcObservable<typename ImplementationPolicy::Field>* > out;
327 out.push_back(i->getPtr());
339 template <
class ReaderClass >
343 Read.readDefault(
"multiplier",m);
345 std::cout <<
"Level : " <<
multipliers.size() <<
" with multiplier : " << m << std::endl;
350 std::string action_type;
351 Read.readDefault(
"name", action_type);
353 ActionsList.emplace(m, ActionFactory.create(action_type, Read));
354 }
while (Read.nextElement(
"Action"));
std::vector< ActionLevel< GaugeField, R > > ActionSet
GridLogger GridLogError(1, "Error", GridLogColours, "RED")
GridLogger GridLogDebug(1, "Debug", GridLogColours, "PURPLE")
GridLogger GridLogMessage(1, "Message", GridLogColours, "NORMAL")
#define NAMESPACE_BEGIN(A)
GridRedBlackCartesian * GetRBCartesian(std::string s="")
ImplementationPolicy::Field MomentaField
void fill_ActionsLevel(ReaderClass &Read)
void GetActionSet(ActionSet< typename ImplementationPolicy::Field, RepresentationPolicy > &Aset)
ImplementationPolicy::Field Field
HMCModuleBase< BaseHmcCheckpointer< ImplementationPolicy > > CheckpointerBaseModule
std::multimap< int, std::unique_ptr< ActionBaseModule > > ActionsList
MomentumFilterBase< typename ImplementationPolicy::Field > * GetMomentumFilter(void)
void AddRNGs(std::string s="")
void SetRNGSeeds(RNGModuleParameters &Params)
std::vector< std::unique_ptr< ObservableBaseModule > > ObservablesList
std::unique_ptr< CheckpointerBaseModule > CP
GridParallelRNG & GetParallelRNG()
void AddGrid(const std::string s, GridModule &M)
std::unique_ptr< MomentumFilterBase< typename ImplementationPolicy::Field > > Filter
GridSerialRNG & GetSerialRNG()
void AddFourDimGrid(const std::string s)
void output_vector_string(const std::vector< std::string > &vs)
void LoadBinaryCheckpointer(const CheckpointerParameters &Params_)
void LoadCheckpointer(const CheckpointerParameters &Params_, const Metadata &M_)
std::vector< HmcObservable< typename ImplementationPolicy::Field > * > GetObservables()
void SetMomentumFilter(MomentumFilterBase< typename ImplementationPolicy::Field > *MomFilter)
std::vector< int > multipliers
ActionModuleBase< Action< typename ImplementationPolicy::Field >, GridModule > ActionBaseModule
std::unordered_map< std::string, GridModule > Grids
void AddObservable(Types &&... Args)
GridCartesian * GetCartesian(std::string s="")
void AddFourDimGrid(const std::string s, const std::vector< int > simd_decomposition)
void initialize(ReaderClass &Read)
void LoadCheckpointer(const CheckpointerParameters &Params_)
void LoadNerscCheckpointer(const CheckpointerParameters &Params_)
BaseHmcCheckpointer< ImplementationPolicy > * GetCheckPointer()
HMCModuleBase< HmcObservable< typename ImplementationPolicy::Field > > ObservableBaseModule
static HMC_ActionModuleFactory & getInstance(void)
static HMC_CPModuleFactory & getInstance(void)
static HMC_ObservablesModuleFactory & getInstance(void)