mlpack  3.0.2
serialization.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_TESTS_SERIALIZATION_HPP
13 #define MLPACK_TESTS_SERIALIZATION_HPP
14 
15 #include <boost/serialization/serialization.hpp>
16 #include <boost/archive/xml_iarchive.hpp>
17 #include <boost/archive/xml_oarchive.hpp>
18 #include <boost/archive/text_iarchive.hpp>
19 #include <boost/archive/text_oarchive.hpp>
20 #include <boost/archive/binary_iarchive.hpp>
21 #include <boost/archive/binary_oarchive.hpp>
22 #include <mlpack/core.hpp>
23 
24 #include <boost/test/unit_test.hpp>
25 #include "test_tools.hpp"
26 
27 namespace mlpack {
28 
29 // Test function for loading and saving Armadillo objects.
30 template<typename CubeType,
31  typename IArchiveType,
32  typename OArchiveType>
33 void TestArmadilloSerialization(arma::Cube<CubeType>& x)
34 {
35  // First save it.
36  // Use type_info name to get unique file name for serialization test files.
37  std::string fileName = FilterFileName(typeid(IArchiveType).name());
38  std::ofstream ofs(fileName, std::ios::binary);
39  bool success = true;
40 
41  {
42  OArchiveType o(ofs);
43 
44  try
45  {
46  o << BOOST_SERIALIZATION_NVP(x);
47  }
48  catch (boost::archive::archive_exception& e)
49  {
50  success = false;
51  }
52  }
53 
54  BOOST_REQUIRE_EQUAL(success, true);
55  ofs.close();
56 
57  // Now load it.
58  arma::Cube<CubeType> orig(x);
59  success = true;
60  std::ifstream ifs(fileName, std::ios::binary);
61 
62  {
63  IArchiveType i(ifs);
64 
65  try
66  {
67  i >> BOOST_SERIALIZATION_NVP(x);
68  }
69  catch (boost::archive::archive_exception& e)
70  {
71  success = false;
72  }
73  }
74  ifs.close();
75 
76  remove(fileName.c_str());
77 
78  BOOST_REQUIRE_EQUAL(success, true);
79 
80  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
81  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
82  BOOST_REQUIRE_EQUAL(x.n_elem_slice, orig.n_elem_slice);
83  BOOST_REQUIRE_EQUAL(x.n_slices, orig.n_slices);
84  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
85 
86  for (size_t slice = 0; slice != x.n_slices; ++slice)
87  {
88  const auto& origSlice = orig.slice(slice);
89  const auto& xSlice = x.slice(slice);
90  for (size_t i = 0; i < x.n_cols; ++i)
91  {
92  for (size_t j = 0; j < x.n_rows; ++j)
93  {
94  if (double(origSlice(j, i)) == 0.0)
95  BOOST_REQUIRE_SMALL(double(xSlice(j, i)), 1e-8);
96  else
97  BOOST_REQUIRE_CLOSE(double(origSlice(j, i)), double(xSlice(j, i)),
98  1e-8);
99  }
100  }
101  }
102 }
103 
104 // Test all serialization strategies.
105 template<typename CubeType>
106 void TestAllArmadilloSerialization(arma::Cube<CubeType>& x)
107 {
108  TestArmadilloSerialization<CubeType, boost::archive::xml_iarchive,
109  boost::archive::xml_oarchive>(x);
110  TestArmadilloSerialization<CubeType, boost::archive::text_iarchive,
111  boost::archive::text_oarchive>(x);
112  TestArmadilloSerialization<CubeType, boost::archive::binary_iarchive,
113  boost::archive::binary_oarchive>(x);
114 }
115 
116 // Test function for loading and saving Armadillo objects.
117 template<typename MatType,
118  typename IArchiveType,
119  typename OArchiveType>
121 {
122  // First save it.
123  std::string fileName = FilterFileName(typeid(IArchiveType).name());
124  std::ofstream ofs(fileName, std::ios::binary);
125  bool success = true;
126 
127  {
128  OArchiveType o(ofs);
129 
130  try
131  {
132  o << BOOST_SERIALIZATION_NVP(x);
133  }
134  catch (boost::archive::archive_exception& e)
135  {
136  success = false;
137  }
138  }
139 
140  BOOST_REQUIRE_EQUAL(success, true);
141  ofs.close();
142 
143  // Now load it.
144  MatType orig(x);
145  success = true;
146  std::ifstream ifs(fileName, std::ios::binary);
147 
148  {
149  IArchiveType i(ifs);
150 
151  try
152  {
153  i >> BOOST_SERIALIZATION_NVP(x);
154  }
155  catch (boost::archive::archive_exception& e)
156  {
157  success = false;
158  }
159  }
160  ifs.close();
161 
162  remove(fileName.c_str());
163 
164  BOOST_REQUIRE_EQUAL(success, true);
165 
166  BOOST_REQUIRE_EQUAL(x.n_rows, orig.n_rows);
167  BOOST_REQUIRE_EQUAL(x.n_cols, orig.n_cols);
168  BOOST_REQUIRE_EQUAL(x.n_elem, orig.n_elem);
169 
170  for (size_t i = 0; i < x.n_cols; ++i)
171  for (size_t j = 0; j < x.n_rows; ++j)
172  if (double(orig(j, i)) == 0.0)
173  BOOST_REQUIRE_SMALL(double(x(j, i)), 1e-8);
174  else
175  BOOST_REQUIRE_CLOSE(double(orig(j, i)), double(x(j, i)), 1e-8);
176 }
177 
178 // Test all serialization strategies.
179 template<typename MatType>
181 {
182  TestArmadilloSerialization<MatType, boost::archive::xml_iarchive,
183  boost::archive::xml_oarchive>(x);
184  TestArmadilloSerialization<MatType, boost::archive::text_iarchive,
185  boost::archive::text_oarchive>(x);
186  TestArmadilloSerialization<MatType, boost::archive::binary_iarchive,
187  boost::archive::binary_oarchive>(x);
188 }
189 
190 // Save and load an mlpack object.
191 // The re-loaded copy is placed in 'newT'.
192 template<typename T, typename IArchiveType, typename OArchiveType>
193 void SerializeObject(T& t, T& newT)
194 {
195  std::string fileName = FilterFileName(typeid(T).name());
196  std::ofstream ofs(fileName, std::ios::binary);
197  bool success = true;
198 
199  {
200  OArchiveType o(ofs);
201 
202  try
203  {
204  o << BOOST_SERIALIZATION_NVP(t);
205  }
206  catch (boost::archive::archive_exception& e)
207  {
208  std::cerr << e.what() << std::endl;
209  success = false;
210  }
211  }
212  ofs.close();
213 
214  BOOST_REQUIRE_EQUAL(success, true);
215 
216  std::ifstream ifs(fileName, std::ios::binary);
217 
218  {
219  IArchiveType i(ifs);
220 
221  try
222  {
223  i >> BOOST_SERIALIZATION_NVP(newT);
224  }
225  catch (boost::archive::archive_exception& e)
226  {
227  std::cout << e.what() << "\n";
228  success = false;
229  }
230  }
231  ifs.close();
232 
233  remove(fileName.c_str());
234 
235  BOOST_REQUIRE_EQUAL(success, true);
236 }
237 
238 // Test mlpack serialization with all three archive types.
239 template<typename T>
240 void SerializeObjectAll(T& t, T& xmlT, T& textT, T& binaryT)
241 {
242  SerializeObject<T, boost::archive::xml_iarchive,
243  boost::archive::xml_oarchive>(t, xmlT);
244  SerializeObject<T, boost::archive::text_iarchive,
245  boost::archive::text_oarchive>(t, textT);
246  SerializeObject<T, boost::archive::binary_iarchive,
247  boost::archive::binary_oarchive>(t, binaryT);
248 }
249 
250 // Save and load a non-default-constructible mlpack object.
251 template<typename T, typename IArchiveType, typename OArchiveType>
252 void SerializePointerObject(T* t, T*& newT)
253 {
254  std::string fileName = FilterFileName(typeid(T).name());
255  std::ofstream ofs(fileName, std::ios::binary);
256  bool success = true;
257 
258  {
259  OArchiveType o(ofs);
260  try
261  {
262  o << BOOST_SERIALIZATION_NVP(t);
263  }
264  catch (boost::archive::archive_exception& e)
265  {
266  std::cout << e.what() << "\n";
267  success = false;
268  }
269  }
270  ofs.close();
271 
272  BOOST_REQUIRE_EQUAL(success, true);
273 
274  std::ifstream ifs(fileName, std::ios::binary);
275 
276  {
277  IArchiveType i(ifs);
278 
279  try
280  {
281  i >> BOOST_SERIALIZATION_NVP(newT);
282  }
283  catch (std::exception& e)
284  {
285  std::cout << e.what() << "\n";
286  success = false;
287  }
288  }
289  ifs.close();
290 
291  remove(fileName.c_str());
292 
293  BOOST_REQUIRE_EQUAL(success, true);
294 }
295 
296 template<typename T>
297 void SerializePointerObjectAll(T* t, T*& xmlT, T*& textT, T*& binaryT)
298 {
299  SerializePointerObject<T, boost::archive::text_iarchive,
300  boost::archive::text_oarchive>(t, textT);
301  SerializePointerObject<T, boost::archive::binary_iarchive,
302  boost::archive::binary_oarchive>(t, binaryT);
303  SerializePointerObject<T, boost::archive::xml_iarchive,
304  boost::archive::xml_oarchive>(t, xmlT);
305 }
306 
307 // Utility function to check the equality of two Armadillo matrices.
308 void CheckMatrices(const arma::mat& x,
309  const arma::mat& xmlX,
310  const arma::mat& textX,
311  const arma::mat& binaryX);
312 
313 void CheckMatrices(const arma::Mat<size_t>& x,
314  const arma::Mat<size_t>& xmlX,
315  const arma::Mat<size_t>& textX,
316  const arma::Mat<size_t>& binaryX);
317 
318 void CheckMatrices(const arma::cube& x,
319  const arma::cube& xmlX,
320  const arma::cube& textX,
321  const arma::cube& binaryX);
322 
323 } // namespace mlpack
324 
325 #endif
void SerializePointerObject(T *t, T *&newT)
.hpp
Definition: add_to_po.hpp:21
std::string FilterFileName(const std::string &inputString)
Definition: test_tools.hpp:161
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &textX, const arma::mat &binaryX)
void TestArmadilloSerialization(arma::Cube< CubeType > &x)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&textT, T *&binaryT)
void SerializeObjectAll(T &t, T &xmlT, T &textT, T &binaryT)
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
void SerializeObject(T &t, T &newT)
void TestAllArmadilloSerialization(arma::Cube< CubeType > &x)