reparametrization.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_REPARAMETRIZATION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "layer.hpp"
19 // #include "../activation_functions/softplus_function.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
53 template <
54  typename InputType = arma::mat,
55  typename OutputType = arma::mat
56 >
57 class ReparametrizationType : public Layer<InputType, OutputType>
58 {
59  public:
// Create the Reparametrization layer object.  Every argument has a default,
// so this declaration also acts as the default constructor.  Only the
// declaration is visible here; the definition is expected in
// reparametrization_impl.hpp (included at the end of this header).
//
// @param stochastic Value for the `stochastic` parameter (default: true).
// @param includeKl Value for the `includeKl` parameter (default: true).
// @param beta Value for the `beta` hyperparameter (default: 1).
// NOTE(review): presumably each argument is stored in the member of the same
// name -- confirm against the implementation file.
69  ReparametrizationType(const bool stochastic = true,
70  const bool includeKl = true,
71  const double beta = 1);
72 
// Clone the ReparametrizationType object: returns a pointer to a new
// heap-allocated copy of this layer, built with the copy constructor.
// The caller is presumably responsible for deleting the returned object.
// NOTE(review): the signature line for this body
// (`ReparametrizationType* Clone() const` according to the member index) is
// missing from this listing.
78  {
79  return new ReparametrizationType(*this);
80  }
81 
82  // Virtual destructor.
// The body is empty: the class declares no raw-pointer members, so member
// destructors perform all cleanup.
83  virtual ~ReparametrizationType() { }
84 
87 
90 
93 
96 
// Ordinary feed forward pass of a neural network, evaluating the function
// f(x).  Declaration only; the definition is expected in
// reparametrization_impl.hpp.
//
// @param input Input data (read-only).
// @param output Non-const reference; presumably receives the result of the
//     forward pass -- confirm against the implementation.
111  void Forward(const InputType& input, OutputType& output);
112 
// Ordinary feed backward pass of a neural network.  Declaration only; the
// definition is expected in reparametrization_impl.hpp.
//
// @param input Read-only input; presumably the data propagated in the forward
//     pass -- confirm against the implementation.
// @param gy Read-only; presumably the backpropagated error -- confirm.
// @param g Non-const reference; presumably receives the calculated gradient
//     -- confirm.
122  void Backward(const InputType& input,
123  const OutputType& gy,
124  OutputType& g);
125 
// Get the KL divergence with standard normal.  Declaration only; the
// definition is expected in reparametrization_impl.hpp.
// NOTE(review): how `includeKl` and `beta` affect the returned value is not
// visible in this header -- check the implementation.
127  double Loss();
128 
130  bool Stochastic() const { return stochastic; }
132  bool& Stochastic() { return stochastic; }
133 
135  bool IncludeKL() const { return includeKl; }
137  bool& IncludeKL() { return includeKl; }
138 
140  double Beta() const { return beta; }
142  double& Beta() { return beta; }
143 
145  {
146  const size_t inputElem = std::accumulate(this->inputDimensions.begin(),
147  this->inputDimensions.end(), 0);
148  if (inputElem % 2 != 0)
149  {
150  std::ostringstream oss;
151  oss << "Reparametrization layer requires that the total number of input "
152  << "elements is divisible by 2! (Received input with " << inputElem
153  << " total elements.)";
154  throw std::invalid_argument(oss.str());
155  }
156 
157  this->outputDimensions = std::vector<size_t>(
158  this->inputDimensions.size(), 1);
159  // This flattens the input, and removes half the elements.
160  this->outputDimensions[0] = inputElem / 2;
161  }
162 
// Serialize the layer.  Declaration only; the definition is expected in
// reparametrization_impl.hpp.
//
// @tparam Archive Type of the archive object.
// @param ar Archive to serialize to or from.
// The second parameter (the version) is left unnamed in this declaration.
166  template<typename Archive>
167  void serialize(Archive& ar, const uint32_t /* version */);
168 
169  private:
// The stochastic parameter (exposed through Stochastic()).
171  bool stochastic;
172 
// The includeKl parameter (exposed through IncludeKL()).
174  bool includeKl;
175 
// The beta hyperparameter (exposed through Beta()).
177  double beta;
178 
// NOTE(review): the four matrices below are not used in this header;
// presumably they cache intermediate results of Forward() for Backward() and
// Loss() -- confirm in reparametrization_impl.hpp.
// Presumably the Gaussian noise sample drawn in the forward pass.
180  OutputType gaussianSample;
181 
// Presumably the mean part of the input.
183  OutputType mean;
184 
// Presumably the standard deviation part of the input before the
// positivity-enforcing transform is applied.
187  OutputType preStdDev;
188 
// Presumably the standard deviation after that transform.
190  OutputType stdDev;
191 }; // class ReparametrizationType
192 
193 // Standard Reparametrization layer.
195 
196 } // namespace ann
197 } // namespace mlpack
198 
199 // Include implementation.
200 #include "reparametrization_impl.hpp"
201 
202 #endif
bool IncludeKL() const
Get the value of the includeKl parameter.
void Backward(const InputType &input, const OutputType &gy, OutputType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
std::vector< size_t > inputDimensions
Logical input dimensions of each point.
Definition: layer.hpp:302
void ComputeOutputDimensions()
Compute the output dimensions.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Linear algebra utility functions, generally performed on matrices or vectors.
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
The core includes that mlpack expects; standard C++ includes and Armadillo.
ReparametrizationType * Clone() const
Clone the ReparametrizationType object.
std::vector< size_t > outputDimensions
Logical output dimensions of each point.
Definition: layer.hpp:310
double Beta() const
Get the value of the beta hyperparameter.
bool Stochastic() const
Get the value of the stochastic parameter.
bool & IncludeKL()
Modify the value of the includeKl parameter.
bool & Stochastic()
Modify the value of the stochastic parameter.
ReparametrizationType(const bool stochastic=true, const bool includeKl=true, const double beta=1)
Create the Reparametrization layer object.
A layer is an abstract class implementing common neural network operations, such as convolution, batch normalization, etc.
Definition: layer.hpp:52
double Loss()
Get the KL divergence with standard normal.
ReparametrizationType & operator=(const ReparametrizationType &layer)
Copy assignment operator.
Implementation of the Reparametrization layer class.
double & Beta()
Modify the value of the beta hyperparameter.
ReparametrizationType< arma::mat, arma::mat > Reparametrization