base_layer.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
14 #define MLPACK_METHODS_ANN_LAYER_BASE_LAYER_HPP
15 
16 #include <mlpack/prereqs.hpp>
33 #include "layer.hpp"
34 
35 namespace mlpack {
36 namespace ann {
37 
62 template <
63  class ActivationFunction = LogisticFunction,
64  typename MatType = arma::mat
65 >
66 class BaseLayer : public Layer<MatType>
67 {
68  public:
72  BaseLayer() : Layer<MatType>()
73  {
74  // Nothing to do here.
75  }
76 
77  // Virtual destructor.
78  virtual ~BaseLayer() { }
79 
80  // No copy constructor or operators needed here, since the class has no
81  // members.
82 
84  BaseLayer* Clone() const { return new BaseLayer(*this); }
85 
92  void Forward(const MatType& input, MatType& output)
93  {
94  ActivationFunction::Fn(input, output);
95  }
96 
105  void Backward(const MatType& input, const MatType& gy, MatType& g)
106  {
107  MatType derivative;
108  ActivationFunction::Deriv(input, derivative);
109  g = gy % derivative;
110  }
111 
115  template<typename Archive>
116  void serialize(Archive& ar, const uint32_t /* version */)
117  {
118  ar(cereal::base_class<Layer<MatType>>(this));
119  // Nothing to serialize.
120  }
121 }; // class BaseLayer
122 
123 // Convenience typedefs.
124 
129 
130 template<typename MatType = arma::mat>
132 
137 
138 template<typename MatType = arma::mat>
140 
145 
146 template<typename MatType = arma::mat>
148 
153 
154 template<typename MatType = arma::mat>
156 
161 
162 template<typename MatType = arma::mat>
164 
169 
170 template<typename MatType = arma::mat>
172 
177 
178 template<typename MatType = arma::mat>
180 
185 
186 template<typename MatType = arma::mat>
188 
193 
194 template<typename MatType = arma::mat>
196 
201 
202 template<typename MatType = arma::mat>
204 
209 
210 template<typename MatType = arma::mat>
212 
217 
218 template<typename MatType = arma::mat>
220 
225 
226 template <typename MatType = arma::mat>
228 
233 
234 template<typename MatType = arma::mat>
236 
241 
242 template<typename MatType = arma::mat>
244 
245 } // namespace ann
246 } // namespace mlpack
247 
248 #endif
BaseLayer< GaussianFunction, arma::mat > Gaussian
Standard Gaussian-Layer using the Gaussian activation function.
Definition: base_layer.hpp:216
BaseLayer()
Create the BaseLayer object.
Definition: base_layer.hpp:72
BaseLayer< TanhExpFunction, arma::mat > TanhExp
Standard TanhExp-Layer using the TanhExp activation function.
Definition: base_layer.hpp:232
BaseLayer< RectifierFunction, arma::mat > ReLU
Standard rectified linear unit non-linearity layer.
Definition: base_layer.hpp:136
Linear algebra utility functions, generally performed on matrices or vectors.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Definition: base_layer.hpp:116
The core includes that mlpack expects; standard C++ includes and Armadillo.
BaseLayer< HardSigmoidFunction, arma::mat > HardSigmoid
Standard HardSigmoid-Layer using the HardSigmoid activation function.
Definition: base_layer.hpp:160
BaseLayer< MishFunction, arma::mat > Mish
Standard Mish-Layer using the Mish activation function.
Definition: base_layer.hpp:176
void Backward(const MatType &input, const MatType &gy, MatType &g)
Backward pass: propagate the error gy backwards through the activation by multiplying it element-wise with the derivative computed by ActivationFunction::Deriv(input, ...), storing the result in g.
Definition: base_layer.hpp:105
BaseLayer< SoftplusFunction, arma::mat > SoftPlus
Standard Softplus-Layer using the Softplus activation function.
Definition: base_layer.hpp:152
Implementation of the base layer.
Definition: base_layer.hpp:66
BaseLayer< TanhFunction, arma::mat > TanH
Standard hyperbolic tangent layer.
Definition: base_layer.hpp:144
BaseLayer< ElishFunction, arma::mat > Elish
Standard ELiSH-Layer using the ELiSH activation function.
Definition: base_layer.hpp:208
BaseLayer< GELUFunction, arma::mat > GELU
Standard GELU-Layer using the GELU activation function.
Definition: base_layer.hpp:192
BaseLayer< HardSwishFunction, arma::mat > HardSwish
Standard HardSwish-Layer using the HardSwish activation function.
Definition: base_layer.hpp:224
BaseLayer< SwishFunction, arma::mat > Swish
Standard Swish-Layer using the Swish activation function.
Definition: base_layer.hpp:168
BaseLayer * Clone() const
Clone the BaseLayer object. This handles polymorphism correctly.
Definition: base_layer.hpp:84
BaseLayer< SILUFunction, arma::mat > SILU
Standard SILU-Layer using the SILU activation function.
Definition: base_layer.hpp:240
A layer is an abstract class implementing common neural network operations, such as convolution...
Definition: layer.hpp:52
BaseLayer< ElliotFunction, arma::mat > Elliot
Standard Elliot-Layer using the Elliot activation function.
Definition: base_layer.hpp:200
BaseLayer< LogisticFunction, arma::mat > Sigmoid
Standard Sigmoid-Layer using the logistic activation function.
Definition: base_layer.hpp:128
BaseLayer< LiSHTFunction, arma::mat > LiSHT
Standard LiSHT-Layer using the LiSHT activation function.
Definition: base_layer.hpp:184
void Forward(const MatType &input, MatType &output)
Forward pass: apply the activation to the inputs.
Definition: base_layer.hpp:92