group_norm.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_GROUPNORM_HPP
13 #define MLPACK_METHODS_ANN_LAYER_GROUPNORM_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace ann {
19 
46 template <
47  typename InputDataType = arma::mat,
48  typename OutputDataType = arma::mat
49 >
50 class GroupNorm
51 {
52  public:
54  GroupNorm();
55 
62  GroupNorm(const size_t groupCount, const size_t size, const double eps = 1e-8);
63 
67  void Reset();
68 
77  template<typename eT>
78  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
79 
87  template<typename eT>
88  void Backward(const arma::Mat<eT>& input,
89  const arma::Mat<eT>& gy,
90  arma::Mat<eT>& g);
91 
99  template<typename eT>
100  void Gradient(const arma::Mat<eT>& input,
101  const arma::Mat<eT>& error,
102  arma::Mat<eT>& gradient);
103 
105  OutputDataType const& Parameters() const { return weights; }
107  OutputDataType& Parameters() { return weights; }
108 
110  OutputDataType const& OutputParameter() const { return outputParameter; }
112  OutputDataType& OutputParameter() { return outputParameter; }
113 
115  OutputDataType const& Delta() const { return delta; }
117  OutputDataType& Delta() { return delta; }
118 
120  OutputDataType const& Gradient() const { return gradient; }
122  OutputDataType& Gradient() { return gradient; }
123 
125  OutputDataType Mean() { return mean; }
126 
128  OutputDataType Variance() { return variance; }
129 
131  size_t InSize() const { return size; }
132 
134  double Epsilon() const { return eps; }
135 
137  size_t InputShape() const
138  {
139  return size;
140  }
141 
143  size_t GroupCount() const
144  {
145  return groupCount;
146  }
147 
151  template<typename Archive>
152  void serialize(Archive& ar, const uint32_t /* version */);
153 
154  private:
156  size_t groupCount;
157 
159  size_t size;
160 
162  double eps;
163 
165  bool loading;
166 
168  OutputDataType gamma;
169 
171  OutputDataType beta;
172 
174  OutputDataType weights;
175 
177  OutputDataType mean;
178 
180  OutputDataType variance;
181 
183  OutputDataType gradient;
184 
186  OutputDataType delta;
187 
189  OutputDataType outputParameter;
190 
192  OutputDataType normalized;
193 
195  OutputDataType temp;
196 
198  OutputDataType inputMean;
199 }; // class GroupNorm
200 
201 } // namespace ann
202 } // namespace mlpack
203 
204 // Include the implementation.
205 #include "group_norm_impl.hpp"
206 
207 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: group_norm.hpp:122
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Declaration of the Group Normalization class.
Definition: group_norm.hpp:50
void Reset()
Reset the layer parameters.
OutputDataType & Delta()
Modify the delta.
Definition: group_norm.hpp:117
size_t InSize() const
Get the number of input units.
Definition: group_norm.hpp:131
OutputDataType const & Gradient() const
Get the gradient.
Definition: group_norm.hpp:120
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: group_norm.hpp:112
OutputDataType & Parameters()
Modify the parameters.
Definition: group_norm.hpp:107
OutputDataType const & Parameters() const
Get the parameters.
Definition: group_norm.hpp:105
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Backward pass through the layer.
double Epsilon() const
Get the value of epsilon.
Definition: group_norm.hpp:134
OutputDataType const & Delta() const
Get the delta.
Definition: group_norm.hpp:115
size_t GroupCount() const
Get the group count.
Definition: group_norm.hpp:143
GroupNorm()
Create the GroupNorm object.
OutputDataType Variance()
Get the variance computed over each individual training sample.
Definition: group_norm.hpp:128
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Forward pass of Group Normalization.
size_t InputShape() const
Get the shape of the input.
Definition: group_norm.hpp:137
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: group_norm.hpp:110
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
OutputDataType Mean()
Get the mean computed over each individual training sample.
Definition: group_norm.hpp:125