glorot_init.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_ANN_INIT_RULES_GLOROT_INIT_HPP
15 #define MLPACK_METHODS_ANN_INIT_RULES_GLOROT_INIT_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 #include "random_init.hpp"
19 #include "gaussian_init.hpp"
20 
21 using namespace mlpack::math;
22 
23 namespace mlpack {
24 namespace ann {
25 
54 template<bool Uniform = true>
56 {
57  public:
62  {
63  // Nothing to do here.
64  }
65 
73  template<typename eT>
74  void Initialize(arma::Mat<eT>& W,
75  const size_t rows,
76  const size_t cols);
77 
87  template<typename eT>
88  void Initialize(arma::Cube<eT>& W,
89  const size_t rows,
90  const size_t cols,
91  const size_t slices);
92 }; // class GlorotInitializationType
93 
94 template <>
95 template<typename eT>
96 inline void GlorotInitializationType<false>::Initialize(arma::Mat<eT>& W,
97  const size_t rows,
98  const size_t cols)
99 {
100  if (W.is_empty())
101  W = arma::mat(rows, cols);
102 
103  double var = 2.0/double(rows + cols);
104  GaussianInitialization normalInit(0.0, var);
105  normalInit.Initialize(W, rows, cols);
106 }
107 
108 template <>
109 template<typename eT>
110 inline void GlorotInitializationType<true>::Initialize(arma::Mat<eT>& W,
111  const size_t rows,
112  const size_t cols)
113 {
114  if (W.is_empty())
115  W = arma::mat(rows, cols);
116 
117  // Limit of distribution.
118  double a = sqrt(6) / sqrt(rows + cols);
119  RandomInitialization randomInit(-a, a);
120  randomInit.Initialize(W, rows, cols);
121 }
122 
123 template <bool Uniform>
124 template<typename eT>
125 inline void GlorotInitializationType<Uniform>::Initialize(arma::Cube<eT>& W,
126  const size_t rows,
127  const size_t cols,
128  const size_t slices)
129 {
130  if (W.is_empty())
131  {
132  W = arma::cube(rows, cols, slices);
133  }
134  for (size_t i = 0; i < slices; i++)
135  Initialize(W.slice(i), rows, cols);
136 }
137 
138 // Convenience typedefs.
139 
144 
149 // Uses normal distribution
150 } // namespace ann
151 } // namespace mlpack
152 
153 #endif
.hpp
Definition: add_to_po.hpp:21
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix using a Gaussian distribution.
This class is used to randomly initialize the weight matrix.
Definition: random_init.hpp:24
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Randomly initialize the elements of the specified weight matrix.
Definition: random_init.hpp:56
Miscellaneous math routines.
Definition: ccov.hpp:20
GlorotInitializationType()
Initialize the Glorot initialization object.
Definition: glorot_init.hpp:61
This class is used to initialize the weight matrix with the Glorot Initialization method...
Definition: glorot_init.hpp:55
This class is used to initialize the weight matrix with a Gaussian distribution.