glorot_init.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_ANN_INIT_RULES_GLOROT_INIT_HPP
15 #define MLPACK_METHODS_ANN_INIT_RULES_GLOROT_INIT_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 #include "random_init.hpp"
19 #include "gaussian_init.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
52 template<bool Uniform = true>
54 {
55  public:
60  {
61  // Nothing to do here.
62  }
63 
71  template<typename eT>
72  void Initialize(arma::Mat<eT>& W,
73  const size_t rows,
74  const size_t cols);
75 
81  template<typename eT>
82  void Initialize(arma::Mat<eT>& W);
83 
93  template<typename eT>
94  void Initialize(arma::Cube<eT>& W,
95  const size_t rows,
96  const size_t cols,
97  const size_t slices);
98 
105  template<typename eT>
106  void Initialize(arma::Cube<eT>& W);
107 
111  template<typename Archive>
112  void serialize(Archive& /* ar */, const uint32_t /* version */) { }
113 }; // class GlorotInitializationType
114 
115 template<>
116 template<typename eT>
117 inline void GlorotInitializationType<false>::Initialize(arma::Mat<eT>& W,
118  const size_t rows,
119  const size_t cols)
120 {
121  if (W.is_empty())
122  W.set_size(rows, cols);
123 
124  double var = 2.0 / double(rows + cols);
125  GaussianInitialization normalInit(0.0, var);
126  normalInit.Initialize(W, rows, cols);
127 }
128 
129 template<>
130 template<typename eT>
131 inline void GlorotInitializationType<false>::Initialize(arma::Mat<eT>& W)
132 {
133  if (W.is_empty())
134  Log::Fatal << "Cannot initialize and empty matrix." << std::endl;
135 
136  double var = 2.0 / double(W.n_rows + W.n_cols);
137  GaussianInitialization normalInit(0.0, var);
138  normalInit.Initialize(W);
139 }
140 
141 template<>
142 template<typename eT>
143 inline void GlorotInitializationType<true>::Initialize(arma::Mat<eT>& W,
144  const size_t rows,
145  const size_t cols)
146 {
147  if (W.is_empty())
148  W.set_size(rows, cols);
149 
150  // Limit of distribution.
151  double a = sqrt(6) / sqrt(rows + cols);
152  RandomInitialization randomInit(-a, a);
153  randomInit.Initialize(W, rows, cols);
154 }
155 
156 template<>
157 template<typename eT>
158 inline void GlorotInitializationType<true>::Initialize(arma::Mat<eT>& W)
159 {
160  if (W.is_empty())
161  Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
162 
163  // Limit of distribution.
164  double a = sqrt(6) / sqrt(W.n_rows + W.n_cols);
165  RandomInitialization randomInit(-a, a);
166  randomInit.Initialize(W);
167 }
168 
169 template <bool Uniform>
170 template<typename eT>
171 inline void GlorotInitializationType<Uniform>::Initialize(arma::Cube<eT>& W,
172  const size_t rows,
173  const size_t cols,
174  const size_t slices)
175 {
176  if (W.is_empty())
177  W.set_size(rows, cols, slices);
178 
179  for (size_t i = 0; i < slices; ++i)
180  Initialize(W.slice(i), rows, cols);
181 }
182 
183 template <bool Uniform>
184 template<typename eT>
185 inline void GlorotInitializationType<Uniform>::Initialize(arma::Cube<eT>& W)
186 {
187  if (W.is_empty())
188  Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
189 
190  for (size_t i = 0; i < W.n_slices; ++i)
191  Initialize(W.slice(i));
192 }
193 
194 // Convenience typedefs.
195 
200 
205 // Uses normal distribution
206 } // namespace ann
207 } // namespace mlpack
208 
209 #endif
void serialize(Archive &, const uint32_t)
Serialize the initialization.
Linear algebra utility functions, generally performed on matrices or vectors.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix using a Gaussian distribution.
This class is used to randomly initialize the weight matrix.
Definition: random_init.hpp:24
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize randomly the elements of the specified weight matrix.
Definition: random_init.hpp:56
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
GlorotInitializationType()
Initialize the Glorot initialization object.
Definition: glorot_init.hpp:59
This class is used to initialize the weight matrix with the Glorot Initialization method...
Definition: glorot_init.hpp:53
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix with the Glorot initialization method.
This class is used to initialize the weight matrix with a Gaussian distribution.