batch_norm.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BATCHNORM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
52 template <
53  typename InputDataType = arma::mat,
54  typename OutputDataType = arma::mat
55 >
56 class BatchNorm
57 {
58  public:
60  BatchNorm();
61 
68  BatchNorm(const size_t size, const double eps = 1e-8);
69 
73  void Reset();
74 
83  template<typename eT>
84  void Forward(const arma::Mat<eT>&& input, arma::Mat<eT>&& output);
85 
93  template<typename eT>
94  void Backward(const arma::Mat<eT>&& input,
95  arma::Mat<eT>&& gy,
96  arma::Mat<eT>&& g);
97 
105  template<typename eT>
106  void Gradient(const arma::Mat<eT>&& input,
107  arma::Mat<eT>&& error,
108  arma::Mat<eT>&& gradient);
109 
111  OutputDataType const& Parameters() const { return weights; }
113  OutputDataType& Parameters() { return weights; }
114 
116  OutputDataType const& OutputParameter() const { return outputParameter; }
118  OutputDataType& OutputParameter() { return outputParameter; }
119 
121  OutputDataType const& Delta() const { return delta; }
123  OutputDataType& Delta() { return delta; }
124 
126  OutputDataType const& Gradient() const { return gradient; }
128  OutputDataType& Gradient() { return gradient; }
129 
131  bool Deterministic() const { return deterministic; }
133  bool& Deterministic() { return deterministic; }
134 
136  OutputDataType TrainingMean() { return runningMean; }
137 
139  OutputDataType TrainingVariance() { return runningVariance / count; }
140 
144  template<typename Archive>
145  void serialize(Archive& ar, const unsigned int /* version */);
146 
147  private:
149  size_t size;
150 
152  double eps;
153 
155  bool loading;
156 
158  OutputDataType gamma;
159 
161  OutputDataType beta;
162 
164  OutputDataType weights;
165 
170  bool deterministic;
171 
173  size_t count;
174 
176  OutputDataType mean;
177 
179  OutputDataType variance;
180 
182  OutputDataType runningMean;
183 
185  OutputDataType runningVariance;
186 
188  OutputDataType gradient;
189 
191  OutputDataType delta;
192 
194  OutputDataType outputParameter;
195 
197  OutputDataType normalized;
198 
200  OutputDataType inputMean;
201 }; // class BatchNorm
202 
203 } // namespace ann
204 } // namespace mlpack
205 
206 // Include the implementation.
207 #include "batch_norm_impl.hpp"
208 
209 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: batch_norm.hpp:128
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
void Forward(const arma::Mat< eT > &&input, arma::Mat< eT > &&output)
Forward pass of the Batch Normalization layer.
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Backward(const arma::Mat< eT > &&input, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
Backward pass through the layer.
OutputDataType & Delta()
Modify the delta.
Definition: batch_norm.hpp:123
OutputDataType TrainingMean()
Get the mean over the training data.
Definition: batch_norm.hpp:136
bool Deterministic() const
Get the value of deterministic parameter.
Definition: batch_norm.hpp:131
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: batch_norm.hpp:116
void Reset()
Reset the layer parameters.
OutputDataType & Parameters()
Modify the parameters.
Definition: batch_norm.hpp:113
bool & Deterministic()
Modify the value of deterministic parameter.
Definition: batch_norm.hpp:133
BatchNorm()
Create the BatchNorm object.
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: batch_norm.hpp:118
OutputDataType const & Parameters() const
Get the parameters.
Definition: batch_norm.hpp:111
OutputDataType const & Gradient() const
Get the gradient.
Definition: batch_norm.hpp:126
Declaration of the Batch Normalization layer class.
Definition: batch_norm.hpp:56
OutputDataType TrainingVariance()
Get the variance over the training data.
Definition: batch_norm.hpp:139
OutputDataType const & Delta() const
Get the delta.
Definition: batch_norm.hpp:121