oivs_init.hpp
Go to the documentation of this file.
1 
27 #ifndef MLPACK_METHODS_ANN_INIT_RULES_OIVS_INIT_HPP
28 #define MLPACK_METHODS_ANN_INIT_RULES_OIVS_INIT_HPP
29 
30 #include <mlpack/prereqs.hpp>
32 
33 #include "random_init.hpp"
34 
35 namespace mlpack {
36 namespace ann {
37 
56 template<class ActivationFunction = LogisticFunction>
58 {
59  public:
67  OivsInitialization(const double epsilon = 0.1,
68  const int k = 5,
69  const double gamma = 0.9) :
70  k(k), gamma(gamma),
71  b(std::abs(ActivationFunction::Inv(1 - epsilon) -
72  ActivationFunction::Inv(epsilon)))
73  {
74  }
75 
83  template<typename eT>
84  void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
85  {
86  RandomInitialization randomInit(-gamma, gamma);
87  randomInit.Initialize(W, rows, cols);
88 
89  W = (b / (k * rows)) * arma::sqrt(W + 1);
90  }
91 
97  template<typename eT>
98  void Initialize(arma::Mat<eT>& W)
99  {
100  RandomInitialization randomInit(-gamma, gamma);
101  randomInit.Initialize(W);
102 
103  W = (b / (k * W.n_rows)) * arma::sqrt(W + 1);
104  }
105 
115  template<typename eT>
116  void Initialize(arma::Cube<eT>& W,
117  const size_t rows,
118  const size_t cols,
119  const size_t slices)
120  {
121  if (W.is_empty())
122  W.set_size(rows, cols, slices);
123 
124  for (size_t i = 0; i < slices; ++i)
125  Initialize(W.slice(i), rows, cols);
126  }
127 
134  template<typename eT>
135  void Initialize(arma::Cube<eT>& W)
136  {
137  if (W.is_empty())
138  Log::Fatal << "Cannot initialize an empty cube." << std::endl;
139 
140  for (size_t i = 0; i < W.n_slices; ++i)
141  Initialize(W.slice(i));
142  }
143 
144  private:
146  int k;
147 
149  double gamma;
150 
152  double b;
153 }; // class OivsInitialization
154 
155 
156 } // namespace ann
157 } // namespace mlpack
158 
159 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
This class is used to randomly initialize the weight matrix.
Definition: random_init.hpp:24
The core includes that mlpack expects; standard C++ includes and Armadillo.
This class is used to initialize the weight matrix with the OIVS (optimal initial value setting) method.
Definition: oivs_init.hpp:57
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Initialize the elements of the specified weight matrix with the OIVS method.
Definition: oivs_init.hpp:84
void Initialize(arma::Cube< eT > &W)
Initialize the elements of the specified third-order weight tensor with the OIVS method.
Definition: oivs_init.hpp:135
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Randomly initialize the elements of the specified weight matrix.
Definition: random_init.hpp:56
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
void Initialize(arma::Mat< eT > &W)
Initialize the elements of the specified weight matrix with the OIVS method.
Definition: oivs_init.hpp:98
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Initialize the elements of the specified third-order weight tensor with the OIVS method.
Definition: oivs_init.hpp:116
OivsInitialization(const double epsilon=0.1, const int k=5, const double gamma=0.9)
Create the OIVS initialization rule with the given parameter values.
Definition: oivs_init.hpp:67