layer_norm.hpp
/**
 * @file layer_norm.hpp
 *
 * Declaration of the Layer Normalization class.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_LAYERNORM_HPP
#define MLPACK_METHODS_ANN_LAYER_LAYERNORM_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace ann {
/**
 * Declaration of the Layer Normalization class. The layer normalizes each
 * data point across its features to zero mean and unit variance, then scales
 * and shifts the result with the trainable parameters gamma and beta.
 *
 * @tparam InputDataType Type of the input data (e.g., arma::mat).
 * @tparam OutputDataType Type of the output data (e.g., arma::mat).
 */
template <
    typename InputDataType = arma::mat,
    typename OutputDataType = arma::mat
>
class LayerNorm
{
 public:
  //! Create the LayerNorm object.
  LayerNorm();

  /**
   * Create the LayerNorm object using the specified number of input units.
   *
   * @param size Number of input units.
   * @param eps Epsilon added to the variance for numerical stability.
   */
  LayerNorm(const size_t size, const double eps = 1e-8);

  //! Reset the layer parameters.
  void Reset();

  //! Forward pass of Layer Normalization: writes the normalized, scaled and
  //! shifted activations of `input` into `output`.
  template<typename eT>
  void Forward(const arma::Mat<eT>&& input, arma::Mat<eT>&& output);

  //! Backward pass through the layer: computes the gradient `g` with respect
  //! to the input, given the backpropagated error `gy`.
  template<typename eT>
  void Backward(const arma::Mat<eT>&& input,
                arma::Mat<eT>&& gy,
                arma::Mat<eT>&& g);

  //! Calculate the gradient of the trainable parameters from the input
  //! activations and the error.
  template<typename eT>
  void Gradient(const arma::Mat<eT>&& input,
                arma::Mat<eT>&& error,
                arma::Mat<eT>&& gradient);

  //! Get the parameters.
  OutputDataType const& Parameters() const { return weights; }
  //! Modify the parameters.
  OutputDataType& Parameters() { return weights; }

  //! Get the output parameter.
  OutputDataType const& OutputParameter() const { return outputParameter; }
  //! Modify the output parameter.
  OutputDataType& OutputParameter() { return outputParameter; }

  //! Get the delta.
  OutputDataType const& Delta() const { return delta; }
  //! Modify the delta.
  OutputDataType& Delta() { return delta; }

  //! Get the gradient.
  OutputDataType const& Gradient() const { return gradient; }
  //! Modify the gradient.
  OutputDataType& Gradient() { return gradient; }

  //! Get the mean across single training data.
  OutputDataType Mean() { return mean; }

  //! Get the variance across single training data.
  OutputDataType Variance() { return variance; }

  //! Serialize the layer.
  template<typename Archive>
  void serialize(Archive& ar, const unsigned int /* version */);

 private:
  //! Locally-stored number of input units.
  size_t size;

  //! Locally-stored epsilon value used to avoid division by zero.
  double eps;

  //! Variable to keep track of whether we are in loading or saving mode.
  bool loading;

  //! Locally-stored scale parameter.
  OutputDataType gamma;

  //! Locally-stored shift parameter.
  OutputDataType beta;

  //! Locally-stored parameters (gamma and beta).
  OutputDataType weights;

  //! Locally-stored mean object.
  OutputDataType mean;

  //! Locally-stored variance object.
  OutputDataType variance;

  //! Locally-stored gradient object.
  OutputDataType gradient;

  //! Locally-stored delta object.
  OutputDataType delta;

  //! Locally-stored output parameter object.
  OutputDataType outputParameter;

  //! Locally-stored normalized input.
  OutputDataType normalized;

  //! Locally-stored mean of the input.
  OutputDataType inputMean;
}; // class LayerNorm

} // namespace ann
} // namespace mlpack

// Include the implementation.
#include "layer_norm_impl.hpp"

#endif
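For orientation, the transformation that Forward() applies can be summarized outside the class: each column of the input (one data point) is normalized to zero mean and unit variance over its size features, then scaled by gamma and shifted by beta. The standalone Armadillo sketch below illustrates that computation only; it is not the implementation from layer_norm_impl.hpp, and the free function LayerNormForwardSketch is a hypothetical name used purely for illustration.

#include <armadillo>

// A minimal sketch (not mlpack's implementation) of the per-column layer
// normalization described above. Each column of `input` is one data point;
// `gamma` and `beta` have one entry per feature, mirroring the class members.
arma::mat LayerNormForwardSketch(const arma::mat& input,
                                 const arma::vec& gamma,
                                 const arma::vec& beta,
                                 const double eps = 1e-8)
{
  // Per-column mean and (population) variance over the feature dimension.
  arma::rowvec mean = arma::mean(input, 0);
  arma::rowvec variance = arma::var(input, 1, 0);
  arma::rowvec stdev = arma::sqrt(variance + eps);

  // Normalize each column to zero mean and unit variance; eps guards
  // against division by zero.
  arma::mat normalized = input;
  normalized.each_row() -= mean;
  normalized.each_row() /= stdev;

  // Scale and shift every column with the trainable parameters.
  arma::mat output = normalized;
  output.each_col() %= gamma;
  output.each_col() += beta;
  return output;
}

The class's mean, variance, and normalized members hold the analogous intermediate quantities, which Mean() and Variance() expose for inspection.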