votca 2025.1-dev
Loading...
Searching...
No Matches
checkpointwriter.h
Go to the documentation of this file.
1/*
2 * Copyright 2009-2020 The VOTCA Development Team (http://www.votca.org)
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 *
15 */
16
17#pragma once
18#ifndef VOTCA_XTP_CHECKPOINTWRITER_H
19#define VOTCA_XTP_CHECKPOINTWRITER_H
20
21// Standard includes
22#include <map>
23#include <string>
24#include <type_traits>
25#include <typeinfo>
26#include <vector>
27
28#if defined(__clang__)
29#elif defined(__GNUC__)
30#pragma GCC diagnostic push
31#pragma GCC diagnostic ignored "-Wdeprecated-copy"
32#endif
33
34// Third party includes
35#include <H5Cpp.h>
36
37// VOTCA includes
38#include <votca/tools/linalg.h>
39
40// Local VOTCA includes
41#include "checkpoint_utils.h"
42#include "checkpointtable.h"
43#include "eigen.h"
44
45namespace votca {
46namespace xtp {
47
48using namespace checkpoint_utils;
49
51 public:
52 CheckpointWriter(const CptLoc& loc) : CheckpointWriter(loc, "/") {};
53
54 CheckpointWriter(const CptLoc& loc, const std::string& path)
55 : loc_(loc), path_(path) {};
56
57 // see the following links for details
58 // https://stackoverflow.com/a/8671617/1186564
59 template <typename T>
60 typename std::enable_if<!std::is_fundamental<T>::value>::type operator()(
61 const T& data, const std::string& name) const {
62 std::lock_guard<std::recursive_mutex> lock(checkpoint_utils::Hdf5Mutex());
63 try {
64 WriteData(loc_, data, name);
65 } catch (H5::Exception&) {
66 std::stringstream message;
67 message << "Could not write " << name << " to " << loc_.getFileName()
68 << ":" << path_;
69
70 throw std::runtime_error(message.str());
71 }
72 }
73
74 // Use this overload if T is a fundamental type
75 // int, double, unsigned, etc, but not bool
76 template <typename T>
77 typename std::enable_if<std::is_fundamental<T>::value &&
78 !std::is_same<T, bool>::value>::type
79 operator()(const T& v, const std::string& name) const {
80 std::lock_guard<std::recursive_mutex> lock(checkpoint_utils::Hdf5Mutex());
81
82 try {
83 WriteScalar(loc_, v, name);
84 } catch (H5::Exception&) {
85 std::stringstream message;
86 message << "Could not write " << name << " to " << loc_.getFileName()
87 << ":" << path_ << std::endl;
88
89 throw std::runtime_error(message.str());
90 }
91 }
92
93 void operator()(const bool& v, const std::string& name) const {
94 Index temp = static_cast<Index>(v);
95 std::lock_guard<std::recursive_mutex> lock(checkpoint_utils::Hdf5Mutex());
96
97 try {
98 WriteScalar(loc_, temp, name);
99 } catch (H5::Exception&) {
100 std::stringstream message;
101 message << "Could not write " << name << " to " << loc_.getFileName()
102 << ":" << path_ << std::endl;
103
104 throw std::runtime_error(message.str());
105 }
106 }
107
108 void operator()(const std::string& v, const std::string& name) const {
109 std::lock_guard<std::recursive_mutex> lock(checkpoint_utils::Hdf5Mutex());
110
111 try {
112 WriteScalar(loc_, v, name);
113 } catch (H5::Exception&) {
114 std::stringstream message;
115 message << "Could not write " << name << " to " << loc_.getFileName()
116 << ":" << path_ << std::endl;
117
118 throw std::runtime_error(message.str());
119 }
120 }
121
122 CheckpointWriter openChild(const std::string& childName) const {
123 std::lock_guard<std::recursive_mutex> lock(checkpoint_utils::Hdf5Mutex());
124
125 try {
126 return CheckpointWriter(loc_.openGroup(childName),
127 path_ + "/" + childName);
128 } catch (H5::Exception&) {
129 try {
130 return CheckpointWriter(loc_.createGroup(childName),
131 path_ + "/" + childName);
132 } catch (H5::Exception&) {
133 std::stringstream message;
134 message << "Could not open or create" << loc_.getFileName() << ":/"
135 << path_ << "/" << childName << std::endl;
136
137 throw std::runtime_error(message.str());
138 }
139 }
140 }
141
142 template <typename T>
143 CptTable openTable(const std::string& name, std::size_t nRows,
144 bool compact = false) {
145 std::lock_guard<std::recursive_mutex> lock(checkpoint_utils::Hdf5Mutex());
146
147 CptTable table;
148 try {
149 table = CptTable(name, sizeof(typename T::data), loc_);
150 T::SetupCptTable(table);
151 } catch (H5::Exception&) {
152 try {
153 table = CptTable(name, sizeof(typename T::data), nRows);
154 T::SetupCptTable(table);
155 table.initialize(loc_, compact);
156 } catch (H5::Exception&) {
157 std::stringstream message;
158 message << "Could not open table " << name << " in "
159 << loc_.getFileName() << ":" << path_ << std::endl;
160 throw std::runtime_error(message.str());
161 }
162 }
163
164 return table;
165 }
166
167 private:
169 const std::string path_;
170 template <typename T>
171 void WriteScalar(const CptLoc& loc, const T& value,
172 const std::string& name) const {
173
174 hsize_t dims[1] = {1};
175 H5::DataSpace dp(1, dims);
176 const H5::DataType* dataType = InferDataType<T>::get();
177 H5::Attribute attr;
178 try {
179 attr = loc.createAttribute(name, *dataType, dp);
180 } catch (H5::AttributeIException&) {
181 attr = loc.openAttribute(name);
182 }
183 attr.write(*dataType, &value);
184 }
185
186 void WriteScalar(const CptLoc& loc, const std::string& value,
187 const std::string& name) const {
188 hsize_t dims[1] = {1};
189 H5::DataSpace dp(1, dims);
190 const H5::DataType* strType = InferDataType<std::string>::get();
191
192 H5::Attribute attr;
193
194 try {
195 attr = loc.createAttribute(name, *strType, dp);
196 } catch (H5::AttributeIException&) {
197 attr = loc.openAttribute(name);
198 }
199
200 const char* c_str_copy = value.c_str();
201 attr.write(*strType, &c_str_copy);
202 }
203
204 template <typename T>
205 void WriteData(const CptLoc& loc, const Eigen::MatrixBase<T>& matrix,
206 const std::string& name) const {
207
208 hsize_t matRows = hsize_t(matrix.rows());
209 hsize_t matCols = hsize_t(matrix.cols());
210
211 hsize_t dims[2] = {matRows, matCols}; // eigen vectors are n,1 matrices
212
213 if (dims[1] == 0) {
214 dims[1] = 1;
215 }
216
217 H5::DataSpace dp(2, dims);
218 const H5::DataType* dataType = InferDataType<typename T::Scalar>::get();
219 H5::DataSet dataset;
220 try {
221 dataset = loc.createDataSet(name.c_str(), *dataType, dp);
222 } catch (H5::GroupIException&) {
223 dataset = loc.openDataSet(name.c_str());
224 }
225
226 hsize_t matColSize = matrix.derived().outerStride();
227
228 hsize_t fileRows = matCols;
229
230 hsize_t fStride[2] = {1, fileRows};
231 hsize_t fCount[2] = {1, 1};
232 hsize_t fBlock[2] = {1, fileRows};
233
234 hsize_t mStride[2] = {matColSize, 1};
235 hsize_t mCount[2] = {1, 1};
236 hsize_t mBlock[2] = {matCols, 1};
237
238 hsize_t mDim[2] = {matCols, matColSize};
239 H5::DataSpace mspace(2, mDim);
240
241 for (hsize_t i = 0; i < matRows; i++) {
242 hsize_t fStart[2] = {i, 0};
243 hsize_t mStart[2] = {0, i};
244 dp.selectHyperslab(H5S_SELECT_SET, fCount, fStart, fStride, fBlock);
245 mspace.selectHyperslab(H5S_SELECT_SET, mCount, mStart, mStride, mBlock);
246 dataset.write(matrix.derived().data(), *dataType, mspace, dp);
247 }
248 }
249
250 template <typename T>
251 typename std::enable_if<std::is_fundamental<T>::value>::type WriteData(
252 const CptLoc& loc, const std::vector<T> v,
253 const std::string& name) const {
254 hsize_t dims[2] = {(hsize_t)v.size(), 1};
255
256 const H5::DataType* dataType = InferDataType<T>::get();
257 H5::DataSet dataset;
258 H5::DataSpace dp(2, dims);
259 try {
260 dataset = loc.createDataSet(name.c_str(), *dataType, dp);
261 } catch (H5::GroupIException&) {
262 dataset = loc.openDataSet(name.c_str());
263 }
264 dataset.write(v.data(), *dataType);
265 }
266
267 void WriteData(const CptLoc& loc, const std::vector<std::string>& v,
268 const std::string& name) const {
269
270 hsize_t dims[1] = {(hsize_t)v.size()};
271
272 std::vector<const char*> c_str_copy;
273 c_str_copy.reserve(v.size());
274 for (const std::string& s : v) {
275 c_str_copy.push_back(s.c_str());
276 }
277 const H5::DataType* dataType = InferDataType<std::string>::get();
278 H5::DataSet dataset;
279 H5::DataSpace dp(1, dims);
280 try {
281 dataset = loc.createDataSet(name.c_str(), *dataType, dp);
282 } catch (H5::GroupIException&) {
283 dataset = loc.openDataSet(name.c_str());
284 }
285 dataset.write(c_str_copy.data(), *dataType);
286 }
287
288 void WriteData(const CptLoc& loc, const std::vector<Eigen::Vector3d>& v,
289 const std::string& name) const {
290
291 size_t c = 0;
292 std::string r;
293 CptLoc parent;
294 try {
295 parent = loc.createGroup(name);
296 } catch (H5::GroupIException&) {
297 parent = loc.openGroup(name);
298 }
299 for (auto const& x : v) {
300 r = std::to_string(c);
301 WriteData(parent, x, "ind" + r);
302 ++c;
303 }
304 }
305
306 void WriteData(const CptLoc& loc, const tools::EigenSystem& sys,
307 const std::string& name) const {
308
309 CptLoc parent;
310 try {
311 parent = loc.createGroup(name);
312 } catch (H5::GroupIException&) {
313 parent = loc.openGroup(name);
314 }
315
316 WriteData(parent, sys.eigenvalues(), "eigenvalues");
317 WriteData(parent, sys.eigenvectors(), "eigenvectors");
318 WriteData(parent, sys.eigenvectors2(), "eigenvectors2");
319 WriteScalar(parent, Index(sys.info()), "info");
320 }
321
322 template <typename T1, typename T2>
323 void WriteData(const CptLoc& loc, const std::map<T1, std::vector<T2>> map,
324 const std::string& name) const {
325
326 size_t c = 0;
327 std::string r;
328 // Iterate over the map and write map as a number of vectors with T1 as
329 // index
330 for (auto const& x : map) {
331 r = std::to_string(c);
332 CptLoc tempGr;
333 try {
334 tempGr = loc.createGroup(name);
335 } catch (H5::GroupIException&) {
336 tempGr = loc.openGroup(name);
337 }
338 WriteData(tempGr, x.second, "index" + r);
339 ++c;
340 }
341 }
342};
343} // namespace xtp
344} // namespace votca
345
346#if defined(__clang__)
347#elif defined(__GNUC__)
348#pragma GCC diagnostic pop
349#endif
350
351#endif // VOTCA_XTP_CHECKPOINTWRITER_H
const Eigen::MatrixXd & eigenvectors2() const
Definition eigensystem.h:36
Eigen::ComputationInfo info() const
Definition eigensystem.h:39
const Eigen::VectorXd & eigenvalues() const
Definition eigensystem.h:30
const Eigen::MatrixXd & eigenvectors() const
Definition eigensystem.h:33
CheckpointWriter(const CptLoc &loc, const std::string &path)
std::enable_if< std::is_fundamental< T >::value &&!std::is_same< T, bool >::value >::type operator()(const T &v, const std::string &name) const
CheckpointWriter(const CptLoc &loc)
void WriteScalar(const CptLoc &loc, const std::string &value, const std::string &name) const
void WriteData(const CptLoc &loc, const Eigen::MatrixBase< T > &matrix, const std::string &name) const
CheckpointWriter openChild(const std::string &childName) const
void operator()(const std::string &v, const std::string &name) const
void WriteScalar(const CptLoc &loc, const T &value, const std::string &name) const
void operator()(const bool &v, const std::string &name) const
std::enable_if<!std::is_fundamental< T >::value >::type operator()(const T &data, const std::string &name) const
void WriteData(const CptLoc &loc, const std::vector< Eigen::Vector3d > &v, const std::string &name) const
void WriteData(const CptLoc &loc, const std::map< T1, std::vector< T2 > > map, const std::string &name) const
void WriteData(const CptLoc &loc, const tools::EigenSystem &sys, const std::string &name) const
std::enable_if< std::is_fundamental< T >::value >::type WriteData(const CptLoc &loc, const std::vector< T > v, const std::string &name) const
CptTable openTable(const std::string &name, std::size_t nRows, bool compact=false)
void WriteData(const CptLoc &loc, const std::vector< std::string > &v, const std::string &name) const
void initialize(const CptLoc &loc, bool compact)
std::recursive_mutex & Hdf5Mutex()
H5::Group CptLoc
Provides a means for comparing floating point numbers.
Definition basebead.h:33
Eigen::Index Index
Definition types.h:26