lecun_normal_init.hpp
Go to the documentation of this file.
1
15 #ifndef MLPACK_METHODS_ANN_INIT_RULES_LECUN_NORMAL_INIT_HPP
16 #define MLPACK_METHODS_ANN_INIT_RULES_LECUN_NORMAL_INIT_HPP
17
18 #include <mlpack/prereqs.hpp>
20
21 namespace mlpack {
22 namespace ann {
23
50 {
51  public:
56  {
57  // Nothing to do here.
58  }
59
68  template <typename eT>
69  void Initialize(arma::Mat<eT>& W,
70  const size_t rows,
71  const size_t cols)
72  {
73  // He initialization rule says to initialize weights with random
74  // values taken from a gaussian distribution with mean = 0 and
75  // standard deviation = sqrt(1 / rows), i.e. variance = (1 / rows).
76  const double variance = 1.0 / ((double) rows);
77
78  if (W.is_empty())
79  W.set_size(rows, cols);
80
81  // Multipling a random variable X with variance V(X) by some factor c,
82  // then the variance V(cX) = (c ^ 2) * V(X).
83  W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
84  }
85
92  template <typename eT>
93  void Initialize(arma::Mat<eT>& W)
94  {
95  // He initialization rule says to initialize weights with random
96  // values taken from a gaussian distribution with mean = 0 and
97  // standard deviation = sqrt(1 / rows), i.e. variance = (1 / rows).
98  const double variance = 1.0 / (double) W.n_rows;
99
100  if (W.is_empty())
101  Log::Fatal << "Cannot initialize an empty matrix." << std::endl;
102
103  // Multipling a random variable X with variance V(X) by some factor c,
104  // then the variance V(cX) = (c ^ 2) * V(X).
105  W.imbue( [&]() { return sqrt(variance) * arma::randn(); } );
106  }
107
117  template <typename eT>
118  void Initialize(arma::Cube<eT> & W,
119  const size_t rows,
120  const size_t cols,
121  const size_t slices)
122  {
123  if (W.is_empty())
124  W.set_size(rows, cols, slices);
125
126  for (size_t i = 0; i < slices; ++i)
127  Initialize(W.slice(i), rows, cols);
128  }
129
136  template <typename eT>
137  void Initialize(arma::Cube<eT> & W)
138  {
139  if (W.is_empty())
140  Log::Fatal << "Cannot initialize an empty cube." << std::endl;
141
142  for (size_t i = 0; i < W.n_slices; ++i)
143  Initialize(W.slice(i));
144  }
145 }; // class LecunNormalInitialization
146
147 } // namespace ann
148 } // namespace mlpack
149
150 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified weight third-order tensor with the LeCun Normal initialization rule...
The core includes that mlpack expects; standard C++ includes and Armadillo.
This class is used to initialize a weight matrix with the LeCun Normal initialization rule...
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
void Initialize(arma::Cube< eT > &W)
Initialize the elements of the specified weight third-order tensor with the LeCun Normal initialization rule...
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the weight matrix with the LeCun Normal initialization rule...
LecunNormalInitialization()
Initialize the LecunNormalInitialization object.
Miscellaneous math random-related routines.
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the weight matrix with the LeCun Normal initialization rule...
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() macro(post_go_setup) if(BUILD_GO_BINDINGS) file(APPEND "\$
Definition: CMakeLists.txt:3