lstm.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_LSTM_HPP
13 #define MLPACK_METHODS_ANN_LAYER_LSTM_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include <limits>
17 
18 #include "layer.hpp"
19 
20 namespace mlpack {
21 namespace ann {
22 
58 template<typename MatType = arma::mat>
59 class LSTMType : public RecurrentLayer<MatType>
60 {
61  public:
63  LSTMType();
64 
71  LSTMType(const size_t outSize);
72 
74  LSTMType* Clone() const { return new LSTMType(*this); }
75 
77  LSTMType(const LSTMType& other);
79  LSTMType(LSTMType&& other);
81  LSTMType& operator=(const LSTMType& other);
83  LSTMType& operator=(LSTMType&& other);
84 
85  virtual ~LSTMType() { }
86 
91  void SetWeights(typename MatType::elem_type* weightsPtr);
92 
100  void Forward(const MatType& input, MatType& output);
101 
111  void Backward(const MatType& input, const MatType& gy, MatType& g);
112 
113  /*
114  * Calculate the gradient using the output delta and the input activation.
115  *
116  * @param input The input parameter used for calculating the gradient.
117  * @param error The calculated error.
118  * @param gradient The calculated gradient.
119  */
120  void Gradient(const MatType& input,
121  const MatType& error,
122  MatType& gradient);
123 
131  void ClearRecurrentState(const size_t bpttSteps, const size_t batchSize);
132 
134  const MatType& Parameters() const { return weights; }
136  MatType& Parameters() { return weights; }
137 
139  size_t WeightSize() const
140  {
141  return (4 * outSize * inSize + 7 * outSize + 4 * outSize * outSize);
142  }
143 
146  {
147  inSize = std::accumulate(this->inputDimensions.begin(),
148  this->inputDimensions.end(), 0);
149  this->outputDimensions = std::vector<size_t>(this->inputDimensions.size(),
150  1);
151 
152  // The LSTM layer flattens its input.
153  this->outputDimensions[0] = outSize;
154  }
155 
159  template<typename Archive>
160  void serialize(Archive& ar, const uint32_t /* version */);
161 
162  private:
164  size_t inSize;
165 
167  size_t outSize;
168 
170  MatType weights;
171 
173  MatType output2GateInputWeight;
174 
176  MatType input2GateInputWeight;
177 
179  MatType input2GateInputBias;
180 
182  MatType cell2GateInputWeight;
183 
185  MatType output2GateForgetWeight;
186 
188  MatType input2GateForgetWeight;
189 
191  MatType input2GateForgetBias;
192 
194  MatType cell2GateForgetWeight;
195 
197  MatType output2GateOutputWeight;
198 
200  MatType input2GateOutputWeight;
201 
203  MatType input2GateOutputBias;
204 
206  MatType cell2GateOutputWeight;
207 
208  // Below here are recurrent state matrices.
209 
211  MatType inputGate;
212 
214  MatType forgetGate;
215 
217  MatType hiddenLayer;
218 
220  MatType outputGate;
221 
223  MatType input2HiddenWeight;
224 
226  MatType input2HiddenBias;
227 
229  MatType output2HiddenWeight;
230 
232  arma::Cube<typename MatType::elem_type> cell;
233 
234  // These members store recurrent state.
235 
237  arma::Cube<typename MatType::elem_type> inputGateActivation;
238 
240  arma::Cube<typename MatType::elem_type> forgetGateActivation;
241 
243  arma::Cube<typename MatType::elem_type> outputGateActivation;
244 
246  arma::Cube<typename MatType::elem_type> hiddenLayerActivation;
247 
249  arma::Cube<typename MatType::elem_type> cellActivation;
250 
252  MatType forgetGateError;
253 
255  MatType outputGateError;
256 
258  arma::Cube<typename MatType::elem_type> outParameter;
259 
261  MatType inputCellError;
262 
264  MatType inputGateError;
265 
267  MatType hiddenError;
268 }; // class LSTMType
269 
270 // Convenience typedefs.
271 
272 // Standard LSTM layer.
274 
275 } // namespace ann
276 } // namespace mlpack
277 
278 // Include implementation.
279 #include "lstm_impl.hpp"
280 
281 #endif
std::vector< size_t > inputDimensions
Logical input dimensions of each point.
Definition: layer.hpp:319
Linear algebra utility functions, generally performed on matrices or vectors.
LSTMType * Clone() const
Clone the LSTMType object. This handles polymorphism correctly.
Definition: lstm.hpp:74
void SetWeights(typename MatType::elem_type *weightsPtr)
Reset the layer parameter.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpack extensions.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
virtual ~LSTMType()
Definition: lstm.hpp:85
LSTMType & operator=(const LSTMType &other)
Copy the given LSTMType object.
size_t WeightSize() const
Get the total number of trainable parameters.
Definition: lstm.hpp:139
void Forward(const MatType &input, MatType &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
std::vector< size_t > outputDimensions
Logical output dimensions of each point.
Definition: layer.hpp:327
Implementation of the LSTM module class.
Definition: lstm.hpp:59
The RecurrentLayer provides a base layer for all layers that have recurrent functionality and store s...
void ComputeOutputDimensions()
Given a properly set InputDimensions(), compute the output dimensions.
Definition: lstm.hpp:145
void ClearRecurrentState(const size_t bpttSteps, const size_t batchSize)
Reset the recurrent state of the LSTM layer, and allocate enough space to hold bpttSteps of previous ...
void Backward(const MatType &input, const MatType &gy, MatType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
LSTMType< arma::mat > LSTM
Definition: lstm.hpp:273
const MatType & Parameters() const
Get the parameters.
Definition: lstm.hpp:134
MatType & Parameters()
Modify the parameters.
Definition: lstm.hpp:136
LSTMType()
Create the LSTM object.
void Gradient(const MatType &input, const MatType &error, MatType &gradient)
Computing the gradient of the layer with respect to its own input.