random_init.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_INIT_RULES_RANDOM_INIT_HPP
14 #define MLPACK_METHODS_ANN_INIT_RULES_RANDOM_INIT_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
25 {
26  public:
34  RandomInitialization(const double lowerBound = -1,
35  const double upperBound = 1) :
36  lowerBound(lowerBound), upperBound(upperBound) { }
37 
45  RandomInitialization(const double bound) :
46  lowerBound(-std::abs(bound)), upperBound(std::abs(bound)) { }
47 
55  template<typename eT>
56  void Initialize(arma::Mat<eT>& W, const size_t rows, const size_t cols)
57  {
58  W = lowerBound + arma::randu<arma::Mat<eT>>(rows, cols) *
59  (upperBound - lowerBound);
60  }
61 
69  template<typename eT>
70  void Initialize(arma::Cube<eT>& W,
71  const size_t rows,
72  const size_t cols,
73  const size_t slices)
74  {
75  W = arma::Cube<eT>(rows, cols, slices);
76 
77  for (size_t i = 0; i < slices; i++)
78  Initialize(W.slice(i), rows, cols);
79  }
80 
81  private:
83  double lowerBound;
84 
86  double upperBound;
87 }; // class RandomInitialization
88 
89 } // namespace ann
90 } // namespace mlpack
91 
92 #endif
.hpp
Definition: add_to_po.hpp:21
This class is used to randomly initialize the weight matrix.
Definition: random_init.hpp:24
The core includes that mlpack expects; standard C++ includes and Armadillo.
Definition: prereqs.hpp:55
RandomInitialization(const double lowerBound=-1, const double upperBound=1)
Initialize the random initialization rule with the given lower bound and upper bound.
Definition: random_init.hpp:34
void Initialize(arma::Mat< eT > &W, const size_t rows, const size_t cols)
Randomly initialize the elements of the specified weight matrix.
Definition: random_init.hpp:56
RandomInitialization(const double bound)
Initialize the random initialization rule with the given bound.
Definition: random_init.hpp:45
void Initialize(arma::Cube< eT > &W, const size_t rows, const size_t cols, const size_t slices)
Randomly initialize the elements of the specified third-order weight tensor.
Definition: random_init.hpp:70