recurrent.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
13 #define MLPACK_METHODS_ANN_LAYER_RECURRENT_HPP
14 
15 #include <mlpack/core.hpp>
16 
17 #include "../visitor/delete_visitor.hpp"
18 #include "../visitor/delta_visitor.hpp"
19 #include "../visitor/copy_visitor.hpp"
20 #include "../visitor/output_parameter_visitor.hpp"
21 #include "../visitor/input_shape_visitor.hpp"
22 
23 #include "layer_types.hpp"
24 #include "add_merge.hpp"
25 #include "sequential.hpp"
26 
27 namespace mlpack {
28 namespace ann {
29 
39 template <
40  typename InputDataType = arma::mat,
41  typename OutputDataType = arma::mat,
42  typename... CustomLayers
43 >
44 class Recurrent
45 {
46  public:
51  Recurrent();
52 
54  Recurrent(const Recurrent&);
55 
65  template<typename StartModuleType,
66  typename InputModuleType,
67  typename FeedbackModuleType,
68  typename TransferModuleType>
69  Recurrent(const StartModuleType& start,
70  const InputModuleType& input,
71  const FeedbackModuleType& feedback,
72  const TransferModuleType& transfer,
73  const size_t rho);
74 
82  template<typename eT>
83  void Forward(const arma::Mat<eT>& input, arma::Mat<eT>& output);
84 
94  template<typename eT>
95  void Backward(const arma::Mat<eT>& /* input */,
96  const arma::Mat<eT>& gy,
97  arma::Mat<eT>& g);
98 
99  /*
100  * Calculate the gradient using the output delta and the input activation.
101  *
102  * @param input The input parameter used for calculating the gradient.
103  * @param error The calculated error.
104  * @param gradient The calculated gradient.
105  */
106  template<typename eT>
107  void Gradient(const arma::Mat<eT>& input,
108  const arma::Mat<eT>& error,
109  arma::Mat<eT>& /* gradient */);
110 
112  std::vector<LayerTypes<CustomLayers...> >& Model() { return network; }
113 
115  bool Deterministic() const { return deterministic; }
117  bool& Deterministic() { return deterministic; }
118 
120  OutputDataType const& Parameters() const { return parameters; }
122  OutputDataType& Parameters() { return parameters; }
123 
125  OutputDataType const& OutputParameter() const { return outputParameter; }
127  OutputDataType& OutputParameter() { return outputParameter; }
128 
130  OutputDataType const& Delta() const { return delta; }
132  OutputDataType& Delta() { return delta; }
133 
135  OutputDataType const& Gradient() const { return gradient; }
137  OutputDataType& Gradient() { return gradient; }
138 
140  size_t const& Rho() const { return rho; }
141 
143  size_t InputShape() const;
144 
148  template<typename Archive>
149  void serialize(Archive& ar, const uint32_t /* version */);
150 
151  private:
153  DeleteVisitor deleteVisitor;
154 
156  CopyVisitor<CustomLayers...> copyVisitor;
157 
159  LayerTypes<CustomLayers...> startModule;
160 
162  LayerTypes<CustomLayers...> inputModule;
163 
165  LayerTypes<CustomLayers...> feedbackModule;
166 
168  LayerTypes<CustomLayers...> transferModule;
169 
171  size_t rho;
172 
174  size_t forwardStep;
175 
177  size_t backwardStep;
178 
180  size_t gradientStep;
181 
183  bool deterministic;
184 
187  bool ownsLayer;
188 
190  OutputDataType parameters;
191 
193  LayerTypes<CustomLayers...> initialModule;
194 
196  LayerTypes<CustomLayers...> recurrentModule;
197 
199  std::vector<LayerTypes<CustomLayers...> > network;
200 
202  LayerTypes<CustomLayers...> mergeModule;
203 
205  DeltaVisitor deltaVisitor;
206 
208  OutputParameterVisitor outputParameterVisitor;
209 
211  std::vector<arma::mat> feedbackOutputParameter;
212 
214  OutputDataType delta;
215 
217  OutputDataType gradient;
218 
220  OutputDataType outputParameter;
221 
223  arma::mat recurrentError;
224 }; // class Recurrent
225 
226 } // namespace ann
227 } // namespace mlpack
228 
229 // Include implementation.
230 #include "recurrent_impl.hpp"
231 
232 #endif
DeleteVisitor executes the destructor of the instantiated object.
OutputDataType const & Delta() const
Get the delta.
Definition: recurrent.hpp:130
Linear algebra utility functions, generally performed on matrices or vectors.
std::vector< LayerTypes< CustomLayers... > > & Model()
Get the model modules.
Definition: recurrent.hpp:112
bool & Deterministic()
Modify the value of the deterministic parameter.
Definition: recurrent.hpp:117
This visitor is to support copy constructor for neural network module.
boost::variant< AdaptiveMaxPooling< arma::mat, arma::mat > *, AdaptiveMeanPooling< arma::mat, arma::mat > *, Add< arma::mat, arma::mat > *, AddMerge< arma::mat, arma::mat > *, AlphaDropout< arma::mat, arma::mat > *, AtrousConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, BaseLayer< LogisticFunction, arma::mat, arma::mat > *, BaseLayer< IdentityFunction, arma::mat, arma::mat > *, BaseLayer< TanhFunction, arma::mat, arma::mat > *, BaseLayer< SoftplusFunction, arma::mat, arma::mat > *, BaseLayer< RectifierFunction, arma::mat, arma::mat > *, BatchNorm< arma::mat, arma::mat > *, BilinearInterpolation< arma::mat, arma::mat > *, CELU< arma::mat, arma::mat > *, Concat< arma::mat, arma::mat > *, Concatenate< arma::mat, arma::mat > *, ConcatPerformance< NegativeLogLikelihood< arma::mat, arma::mat >, arma::mat, arma::mat > *, Constant< arma::mat, arma::mat > *, Convolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< FullConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, CReLU< arma::mat, arma::mat > *, DropConnect< arma::mat, arma::mat > *, Dropout< arma::mat, arma::mat > *, ELU< arma::mat, arma::mat > *, FastLSTM< arma::mat, arma::mat > *, FlexibleReLU< arma::mat, arma::mat > *, GRU< arma::mat, arma::mat > *, HardTanH< arma::mat, arma::mat > *, Join< arma::mat, arma::mat > *, LayerNorm< arma::mat, arma::mat > *, LeakyReLU< arma::mat, arma::mat > *, Linear< arma::mat, arma::mat, NoRegularizer > *, LinearNoBias< arma::mat, arma::mat, NoRegularizer > *, LogSoftMax< arma::mat, arma::mat > *, Lookup< arma::mat, arma::mat > *, LSTM< arma::mat, arma::mat > *, MaxPooling< arma::mat, arma::mat > *, MeanPooling< arma::mat, arma::mat > *, MiniBatchDiscrimination< arma::mat, arma::mat > *, MultiplyConstant< arma::mat, arma::mat > *, MultiplyMerge< arma::mat, arma::mat > *, NegativeLogLikelihood< arma::mat, arma::mat > *, NoisyLinear< arma::mat, 
arma::mat > *, Padding< arma::mat, arma::mat > *, PReLU< arma::mat, arma::mat > *, Softmax< arma::mat, arma::mat > *, SpatialDropout< arma::mat, arma::mat > *, TransposedConvolution< NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, NaiveConvolution< ValidConvolution >, arma::mat, arma::mat > *, WeightNorm< arma::mat, arma::mat > *, MoreTypes, CustomLayers *... > LayerTypes
OutputDataType const & Parameters() const
Get the parameters.
Definition: recurrent.hpp:120
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
size_t const & Rho() const
Get the number of steps to backpropagate through time.
Definition: recurrent.hpp:140
OutputDataType & Delta()
Modify the delta.
Definition: recurrent.hpp:132
OutputDataType const & Gradient() const
Get the gradient.
Definition: recurrent.hpp:135
void Forward(const arma::Mat< eT > &input, arma::Mat< eT > &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputDataType & Gradient()
Modify the gradient.
Definition: recurrent.hpp:137
OutputParameterVisitor exposes the output parameter of the given module.
OutputDataType & Parameters()
Modify the parameters.
Definition: recurrent.hpp:122
Recurrent()
Default constructor—this will create a Recurrent object that can't be used, so be careful! Make sure to set all the parameters before use.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
DeltaVisitor exposes the delta parameter of the given module.
bool Deterministic() const
The value of the deterministic parameter.
Definition: recurrent.hpp:115
void Backward(const arma::Mat< eT > &, const arma::Mat< eT > &gy, arma::Mat< eT > &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
size_t InputShape() const
Get the shape of the input.
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: recurrent.hpp:125
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: recurrent.hpp:127