serialization.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_TESTS_SERIALIZATION_CATCH_HPP
13 #define MLPACK_TESTS_SERIALIZATION_CATCH_HPP
14 
15 #include <mlpack/core.hpp>
16 
17 #include "test_catch_tools.hpp"
18 #include "catch.hpp"
19 
20 namespace mlpack {
21 
22 // Test function for loading and saving Armadillo objects.
23 template<typename CubeType,
24  typename IArchiveType,
25  typename OArchiveType>
26 void TestArmadilloSerialization(arma::Cube<CubeType>& x)
27 {
28  // First save it.
29  // Use type_info name to get unique file name for serialization test files.
30  std::string fileName = FilterFileName(typeid(IArchiveType).name());
31  std::ofstream ofs(fileName, std::ios::binary);
32  bool success = true;
33 
34  {
35  OArchiveType o(ofs);
36  o(CEREAL_NVP(x));
37  }
38 
39  REQUIRE(success == true);
40  ofs.close();
41 
42  // Now load it.
43  arma::Cube<CubeType> orig(x);
44  success = true;
45  std::ifstream ifs(fileName, std::ios::binary);
46 
47  {
48  IArchiveType i(ifs);
49  i(CEREAL_NVP(x));
50  }
51  ifs.close();
52 
53  remove(fileName.c_str());
54 
55  REQUIRE(success == true);
56 
57  REQUIRE(x.n_rows == orig.n_rows);
58  REQUIRE(x.n_cols == orig.n_cols);
59  REQUIRE(x.n_elem_slice == orig.n_elem_slice);
60  REQUIRE(x.n_slices == orig.n_slices);
61  REQUIRE(x.n_elem == orig.n_elem);
62 
63  for (size_t slice = 0; slice != x.n_slices; ++slice)
64  {
65  const auto& origSlice = orig.slice(slice);
66  const auto& xSlice = x.slice(slice);
67  for (size_t i = 0; i < x.n_cols; ++i)
68  {
69  for (size_t j = 0; j < x.n_rows; ++j)
70  {
71  if (double(origSlice(j, i)) == 0.0)
72  REQUIRE(double(xSlice(j, i)) == Approx(0.0).margin(1e-8 / 100));
73  else
74  REQUIRE(double(origSlice(j, i)) ==
75  Approx(double(xSlice(j, i))).epsilon(1e-8 / 100));
76  }
77  }
78  }
79 }
80 
81 // Test all serialization strategies.
82 template<typename CubeType>
83 void TestAllArmadilloSerialization(arma::Cube<CubeType>& x)
84 {
85  TestArmadilloSerialization<CubeType, cereal::XMLInputArchive,
86  cereal::XMLOutputArchive>(x);
87  TestArmadilloSerialization<CubeType, cereal::JSONInputArchive,
88  cereal::JSONOutputArchive>(x);
89  TestArmadilloSerialization<CubeType, cereal::BinaryInputArchive,
90  cereal::BinaryOutputArchive>(x);
91 }
92 
93 // Test function for loading and saving Armadillo objects.
94 template<typename MatType,
95  typename IArchiveType,
96  typename OArchiveType>
98 {
99  // First save it.
100  std::string fileName = FilterFileName(typeid(IArchiveType).name());
101  std::ofstream ofs(fileName, std::ios::binary);
102  bool success = true;
103 
104  {
105  OArchiveType o(ofs);
106  o(CEREAL_NVP(x));
107  }
108 
109  REQUIRE(success == true);
110  ofs.close();
111 
112  // Now load it.
113  MatType orig(x);
114  success = true;
115  std::ifstream ifs(fileName, std::ios::binary);
116 
117  {
118  IArchiveType i(ifs);
119  i(CEREAL_NVP(x));
120  }
121  ifs.close();
122 
123  remove(fileName.c_str());
124 
125  REQUIRE(success == true);
126 
127  REQUIRE(x.n_rows == orig.n_rows);
128  REQUIRE(x.n_cols == orig.n_cols);
129  REQUIRE(x.n_elem == orig.n_elem);
130 
131  for (size_t i = 0; i < x.n_cols; ++i)
132  for (size_t j = 0; j < x.n_rows; ++j)
133  if (double(orig(j, i)) == 0.0)
134  REQUIRE(double(x(j, i)) == Approx(0.0).margin(1e-8 / 100));
135  else
136  REQUIRE(double(orig(j, i)) ==
137  Approx(double(x(j, i))).epsilon(1e-8 / 100));
138 }
139 
140 // Test all serialization strategies.
141 template<typename MatType>
143 {
144  TestArmadilloSerialization<MatType, cereal::XMLInputArchive,
145  cereal::XMLOutputArchive>(x);
146  TestArmadilloSerialization<MatType, cereal::JSONInputArchive,
147  cereal::JSONOutputArchive>(x);
148  TestArmadilloSerialization<MatType, cereal::BinaryInputArchive,
149  cereal::BinaryOutputArchive>(x);
150 }
151 
152 // Save and load an mlpack object.
153 // The re-loaded copy is placed in 'newT'.
154 template<typename T, typename IArchiveType, typename OArchiveType>
155 void SerializeObject(T& t, T& newT)
156 {
157  std::string fileName = FilterFileName(typeid(T).name());
158  std::ofstream ofs(fileName, std::ios::binary);
159  bool success = true;
160 
161  {
162  OArchiveType o(ofs);
163 
164  T& x(t);
165  o(CEREAL_NVP(x));
166  }
167  ofs.close();
168 
169  REQUIRE(success == true);
170 
171  std::ifstream ifs(fileName, std::ios::binary);
172 
173  {
174  IArchiveType i(ifs);
175  T& x(newT);
176  i(CEREAL_NVP(x));
177  }
178  ifs.close();
179 
180  remove(fileName.c_str());
181 
182  REQUIRE(success == true);
183 }
184 
185 // Test mlpack serialization with all three archive types.
186 template<typename T>
187 void SerializeObjectAll(T& t, T& xmlT, T& jsonT, T& binaryT)
188 {
189  SerializeObject<T, cereal::XMLInputArchive,
190  cereal::XMLOutputArchive>(t, xmlT);
191  SerializeObject<T, cereal::JSONInputArchive,
192  cereal::JSONOutputArchive>(t, jsonT);
193  SerializeObject<T, cereal::BinaryInputArchive,
194  cereal::BinaryOutputArchive>(t, binaryT);
195 }
196 
197 // Save and load a non-default-constructible mlpack object.
198 template<typename T, typename IArchiveType, typename OArchiveType>
199 void SerializePointerObject(T* t, T*& newT)
200 {
201  std::string fileName = FilterFileName(typeid(T).name());
202  std::ofstream ofs(fileName, std::ios::binary);
203  bool success = true;
204 
205  {
206  OArchiveType o(ofs);
207  o(CEREAL_POINTER(t));
208  }
209  ofs.close();
210 
211  REQUIRE(success == true);
212 
213  std::ifstream ifs(fileName, std::ios::binary);
214 
215  {
216  IArchiveType i(ifs);
217  i(CEREAL_POINTER(newT));
218  }
219  ifs.close();
220  remove(fileName.c_str());
221 
222  REQUIRE(success == true);
223 }
224 
225 template<typename T>
226 void SerializePointerObjectAll(T* t, T*& xmlT, T*& jsonT, T*& binaryT)
227 {
228  SerializePointerObject<T, cereal::JSONInputArchive,
229  cereal::JSONOutputArchive>(t, jsonT);
230  SerializePointerObject<T, cereal::BinaryInputArchive,
231  cereal::BinaryOutputArchive>(t, binaryT);
232  SerializePointerObject<T, cereal::XMLInputArchive,
233  cereal::XMLOutputArchive>(t, xmlT);
234 }
235 
236 // Utility function to check the equality of two Armadillo matrices.
237 void CheckMatrices(const arma::mat& x,
238  const arma::mat& xmlX,
239  const arma::mat& jsonX,
240  const arma::mat& binaryX);
241 
242 void CheckMatrices(const arma::Mat<size_t>& x,
243  const arma::Mat<size_t>& xmlX,
244  const arma::Mat<size_t>& jsonX,
245  const arma::Mat<size_t>& binaryX);
246 
247 void CheckMatrices(const arma::cube& x,
248  const arma::cube& xmlX,
249  const arma::cube& jsonX,
250  const arma::cube& binaryX);
251 
252 } // namespace mlpack
253 
254 #endif
void SerializePointerObject(T *t, T *&newT)
void CheckMatrices(const arma::mat &x, const arma::mat &xmlX, const arma::mat &jsonX, const arma::mat &binaryX)
Linear algebra utility functions, generally performed on matrices or vectors.
void TestArmadilloSerialization(arma::Cube< CubeType > &x)
void SerializePointerObjectAll(T *t, T *&xmlT, T *&jsonT, T *&binaryT)
std::string FilterFileName(const std::string &inputString)
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
void SerializeObject(T &t, T &newT)
void TestAllArmadilloSerialization(arma::Cube< CubeType > &x)
#define CEREAL_POINTER(T)
Cereal does not support the serialization of raw pointers.
void SerializeObjectAll(T &t, T &xmlT, T &jsonT, T &binaryT)