he_init.hpp
Go to the documentation of this file.
1 
16 #ifndef MLPACK_METHODS_ANN_INIT_RULES_HE_INIT_HPP
17 #define MLPACK_METHODS_ANN_INIT_RULES_HE_INIT_HPP
18 
19 #include <mlpack/prereqs.hpp>
21 
22 namespace mlpack {
23 namespace ann {
24 
46 {
47  public:
52  {
53  // Nothing to do here.
54  }
55 
64  template <typename eT>
65  void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
66  {
67  // He initialization rule says to initialize weights with random
68  // values taken from a gaussian distribution with mean = 0 and
69  // standard deviation = sqrt(2/rows), i.e. variance = (2/rows).
70  const double variance = 2.0 / (double) rows;
71 
72  if (W.is_empty())
73  W.set_size(rows, cols);
74 
75  // Multipling a random variable X with variance V(X) by some factor c,
76  // then the variance V(cX) = (c^2) * V(X).
77  W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
78  }
79 
86  template <typename eT>
87  void Initialize(arma::Mat<eT>& W)
88  {
89  // He initialization rule says to initialize weights with random
90  // values taken from a gaussian distribution with mean = 0 and
91  // standard deviation = sqrt(2 / rows), i.e. variance = (2 / rows).
92  const double variance = 2.0 / (double) W.n_rows;
93 
94  if (W.is_empty())
95  Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
96 
97  // Multipling a random variable X with variance V(X) by some factor c,
98  // then the variance V(cX) = (c^2) * V(X).
99  W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
100  }
101 
111  template <typename eT>
112  void Initialize(arma::Cube<eT> & W,
113  const size_t rows,
114  const size_t cols,
115  const size_t slices)
116  {
117  if (W.is_empty())
118  W.set_size(rows, cols, slices);
119 
120  for (size_t i = 0; i < slices; ++i)
121  Initialize(W.slice(i), rows, cols);
122  }
123 
130  template <typename eT>
131  void Initialize(arma::Cube<eT> & W)
132  {
133  if (W.is_empty())
134  Log::Fatal << "Cannot initialize an empty matrix" << std::endl;
135 
136  for (size_t i = 0; i < W.n_slices; ++i)
137  Initialize(W.slice(i));
138  }
139 
/**
 * Serialization hook for cereal-style archives. The body is empty: nothing
 * is written to or read from the archive, and the version is ignored.
 */
template<typename Archive>
void serialize(Archive&, const uint32_t) { /* Intentionally empty. */ }
145 }; // class HeInitialization
146 
147 } // namespace ann
148 } // namespace mlpack
149 
150 #endif
HeInitialization()
Initialize the HeInitialization object.
Definition: he_init.hpp:51
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
void Initialize(arma::Cube< eT > &W)
Initialize the elements of the specified 3rd-order weight tensor with the He initialization rule.
Definition: he_init.hpp:131
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified 3rd-order weight tensor with the He initialization rule.
Definition: he_init.hpp:112
This class is used to initialize a weight matrix with the He initialization rule given by He et al.
Definition: he_init.hpp:45
void serialize(Archive &, const uint32_t)
Definition: he_init.hpp:141
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the weight matrix with the He initialization rule.
Definition: he_init.hpp:87
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix with the He initialization rule.
Definition: he_init.hpp:65
Miscellaneous math random-related routines.
static util::PrefixedOutStream Fatal
Definition: log.hpp:105
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() macro(post_go_setup) if(BUILD_GO_BINDINGS) file(APPEND "$
Definition: CMakeLists.txt:3