base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
32 
33 namespace mlpack {
34 namespace ann {
35 
64 template <
65  class ActivationFunction = LogisticFunction,
66  typename InputDataType = arma::mat,
67  typename OutputDataType = arma::mat
68 >
69 class BaseLayer
70 {
71  public:
76  {
77  // Nothing to do here.
78  }
79 
87  template<typename InputType, typename OutputType>
88  void Forward(const InputType& input, OutputType& output)
89  {
90  ActivationFunction::Fn(input, output);
91  }
92 
102  template<typename eT>
103  void Backward(const arma::Mat<eT>& input,
104  const arma::Mat<eT>& gy,
105  arma::Mat<eT>& g)
106  {
107  arma::Mat<eT> derivative;
108  ActivationFunction::Deriv(input, derivative);
109  g = gy % derivative;
110  }
111 
113  OutputDataType const& OutputParameter() const { return outputParameter; }
115  OutputDataType& OutputParameter() { return outputParameter; }
116 
118  OutputDataType const& Delta() const { return delta; }
120  OutputDataType& Delta() { return delta; }
121 
125  template<typename Archive>
126  void serialize(Archive& /* ar */, const uint32_t /* version */)
127  {
128  /* Nothing to do here */
129  }
130 
131  private:
133  OutputDataType delta;
134 
136  OutputDataType outputParameter;
137 }; // class BaseLayer
138 
139 // Convenience typedefs.
140 
144 template <
145  class ActivationFunction = LogisticFunction,
146  typename InputDataType = arma::mat,
147  typename OutputDataType = arma::mat
148 >
149 using SigmoidLayer = BaseLayer<
150  ActivationFunction, InputDataType, OutputDataType>;
151 
155 template <
156  class ActivationFunction = IdentityFunction,
157  typename InputDataType = arma::mat,
158  typename OutputDataType = arma::mat
159 >
160 using IdentityLayer = BaseLayer<
161  ActivationFunction, InputDataType, OutputDataType>;
162 
166 template <
167  class ActivationFunction = RectifierFunction,
168  typename InputDataType = arma::mat,
169  typename OutputDataType = arma::mat
170 >
171 using ReLULayer = BaseLayer<
172  ActivationFunction, InputDataType, OutputDataType>;
173 
177 template <
178  class ActivationFunction = TanhFunction,
179  typename InputDataType = arma::mat,
180  typename OutputDataType = arma::mat
181 >
182 using TanHLayer = BaseLayer<
183  ActivationFunction, InputDataType, OutputDataType>;
184 
188 template <
189  class ActivationFunction = SoftplusFunction,
190  typename InputDataType = arma::mat,
191  typename OutputDataType = arma::mat
192 >
193 using SoftPlusLayer = BaseLayer<
194  ActivationFunction, InputDataType, OutputDataType>;
195 
199 template <
200  class ActivationFunction = HardSigmoidFunction,
201  typename InputDataType = arma::mat,
202  typename OutputDataType = arma::mat
203 >
205  ActivationFunction, InputDataType, OutputDataType>;
206 
210 template <
211  class ActivationFunction = SwishFunction,
212  typename InputDataType = arma::mat,
213  typename OutputDataType = arma::mat
214 >
216  ActivationFunction, InputDataType, OutputDataType>;
217 
221 template <
222  class ActivationFunction = MishFunction,
223  typename InputDataType = arma::mat,
224  typename OutputDataType = arma::mat
225 >
227  ActivationFunction, InputDataType, OutputDataType>;
228 
232 template <
233  class ActivationFunction = LiSHTFunction,
234  typename InputDataType = arma::mat,
235  typename OutputDataType = arma::mat
236 >
238  ActivationFunction, InputDataType, OutputDataType>;
239 
243 template <
244  class ActivationFunction = GELUFunction,
245  typename InputDataType = arma::mat,
246  typename OutputDataType = arma::mat
247 >
249  ActivationFunction, InputDataType, OutputDataType>;
250 
254 template <
255  class ActivationFunction = ElliotFunction,
256  typename InputDataType = arma::mat,
257  typename OutputDataType = arma::mat
258 >
260  ActivationFunction, InputDataType, OutputDataType>;
261 
265 template <
266  class ActivationFunction = ElishFunction,
267  typename InputDataType = arma::mat,
268  typename OutputDataType = arma::mat
269 >
271  ActivationFunction, InputDataType, OutputDataType>;
272 
276 template <
277  class ActivationFunction = GaussianFunction,
278  typename InputDataType = arma::mat,
279  typename OutputDataType = arma::mat
280 >
282  ActivationFunction, InputDataType, OutputDataType>;
283 
287 template <
288  class ActivationFunction = HardSwishFunction,
289  typename InputDataType = arma::mat,
290  typename OutputDataType = arma::mat
291 >
293  ActivationFunction, InputDataType, OutputDataType>;
294 
298 template <
299  class ActivationFunction = TanhExpFunction,
300  typename InputDataType = arma::mat,
301  typename OutputDataType = arma::mat
302 >
304  ActivationFunction, InputDataType, OutputDataType>;
305 
306 } // namespace ann
307 } // namespace mlpack
308 
309 #endif
The identity function, defined by $f(x) = x$.
void Forward(const InputType &input, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
Definition: base_layer.hpp:88
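The Hard Swish function, defined by $f(x) = \begin{cases} 0 & x \le -3 \\ x & x \ge 3 \\ \frac{x(x+3)}{6} & \text{otherwise} \end{cases}$.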
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: base_layer.hpp:115
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:75
The LiSHT function, defined by $f(x) = x \tanh(x)$.
OutputDataType & Delta()
Modify the delta.
Definition: base_layer.hpp:120
void Backward(const arma::Mat< eT > &input, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Definition: base_layer.hpp:103
The tanh function, defined by $f(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void serialize(Archive &, const uint32_t)
Serialize the layer.
Definition: base_layer.hpp:126
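The ELiSH function, defined by $f(x) = \begin{cases} \frac{x}{1 + e^{-x}} & x \ge 0 \\ \frac{e^{x} - 1}{1 + e^{-x}} & x < 0 \end{cases}$.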
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: base_layer.hpp:113
OutputDataType const & Delta() const
Get the delta.
Definition: base_layer.hpp:118
Implementation of the base layer.
Definition: base_layer.hpp:69
The Mish function, defined by $f(x) = x \tanh\left(\ln(1 + e^{x})\right)$.
The TanhExp function, defined by $f(x) = x \tanh(e^{x})$.
The logistic function, defined by $f(x) = \frac{1}{1 + e^{-x}}$.
The Gaussian function, defined by $f(x) = e^{-x^{2}}$.
The Elliot function, defined by $f(x) = \frac{x}{1 + |x|}$.
The swish function, defined by $f(x) = \frac{x}{1 + e^{-x}}$.
The softplus function, defined by $f(x) = \ln(1 + e^{x})$.
The hard sigmoid function, defined by $f(x) = \min\left(1, \max(0, 0.2x + 0.5)\right)$.
The GELU function, defined by $f(x) = 0.5\,x\left(1 + \tanh\left(\sqrt{2/\pi}\,\left(x + 0.044715\,x^{3}\right)\right)\right)$.
The rectifier function, defined by $f(x) = \max(0, x)$.