fast_lstm.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
14 #define MLPACK_METHODS_ANN_LAYER_FAST_LSTM_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include <limits>
18 
19 namespace mlpack {
20 namespace ann {
21 
57 template <
58  typename InputDataType = arma::mat,
59  typename OutputDataType = arma::mat
60 >
61 class FastLSTM
62 {
63  public:
64  // Convenience typedefs.
65  typedef typename InputDataType::elem_type InputElemType;
66  typedef typename OutputDataType::elem_type ElemType;
67 
69  FastLSTM();
70 
78  FastLSTM(const size_t inSize,
79  const size_t outSize,
80  const size_t rho = std::numeric_limits<size_t>::max());
81 
89  template<typename InputType, typename OutputType>
90  void Forward(InputType&& input, OutputType&& output);
91 
101  template<typename InputType, typename ErrorType, typename GradientType>
102  void Backward(const InputType&& input,
103  ErrorType&& gy,
104  GradientType&& g);
105 
106  /*
107  * Reset the layer parameter.
108  */
109  void Reset();
110 
111  /*
112  * Resets the cell to accept a new input. This breaks the BPTT chain starts a
113  * new one.
114  *
115  * @param size The current maximum number of steps through time.
116  */
117  void ResetCell(const size_t size);
118 
119  /*
120  * Calculate the gradient using the output delta and the input activation.
121  *
122  * @param input The input parameter used for calculating the gradient.
123  * @param error The calculated error.
124  * @param gradient The calculated gradient.
125  */
126  template<typename InputType, typename ErrorType, typename GradientType>
127  void Gradient(InputType&& input,
128  ErrorType&& error,
129  GradientType&& gradient);
130 
132  size_t Rho() const { return rho; }
134  size_t& Rho() { return rho; }
135 
137  OutputDataType const& Parameters() const { return weights; }
139  OutputDataType& Parameters() { return weights; }
140 
142  OutputDataType const& OutputParameter() const { return outputParameter; }
144  OutputDataType& OutputParameter() { return outputParameter; }
145 
147  OutputDataType const& Delta() const { return delta; }
149  OutputDataType& Delta() { return delta; }
150 
152  OutputDataType const& Gradient() const { return grad; }
154  OutputDataType& Gradient() { return grad; }
155 
159  template<typename Archive>
160  void serialize(Archive& ar, const unsigned int /* version */);
161 
162  private:
169  template<typename InputType, typename OutputType>
170  void FastSigmoid(InputType&& input, OutputType&& sigmoids)
171  {
172  for (size_t i = 0; i < input.n_elem; ++i)
173  sigmoids(i) = FastSigmoid(input(i));
174  }
175 
182  ElemType FastSigmoid(const InputElemType data)
183  {
184  ElemType x = 0.5 * data;
185  ElemType z;
186  if (x >= 0)
187  {
188  if (x < 1.7)
189  z = (1.5 * x / (1 + x));
190  else if (x < 3)
191  z = (0.935409070603099 + 0.0458812946797165 * (x - 1.7));
192  else
193  z = 0.99505475368673;
194  }
195  else
196  {
197  ElemType xx = -x;
198  if (xx < 1.7)
199  z = -(1.5 * xx / (1 + xx));
200  else if (xx < 3)
201  z = -(0.935409070603099 + 0.0458812946797165 * (xx - 1.7));
202  else
203  z = -0.99505475368673;
204  }
205 
206  return 0.5 * (z + 1.0);
207  }
208 
210  size_t inSize;
211 
213  size_t outSize;
214 
216  size_t rho;
217 
219  size_t forwardStep;
220 
222  size_t backwardStep;
223 
225  size_t gradientStep;
226 
228  OutputDataType weights;
229 
231  OutputDataType prevOutput;
232 
234  size_t batchSize;
235 
237  size_t batchStep;
238 
241  size_t gradientStepIdx;
242 
244  OutputDataType cellActivationError;
245 
247  OutputDataType delta;
248 
250  OutputDataType grad;
251 
253  OutputDataType outputParameter;
254 
256  OutputDataType output2GateWeight;
257 
259  OutputDataType input2GateWeight;
260 
262  OutputDataType input2GateBias;
263 
265  OutputDataType gate;
266 
268  OutputDataType gateActivation;
269 
271  OutputDataType stateActivation;
272 
274  OutputDataType cell;
275 
277  OutputDataType cellActivation;
278 
280  OutputDataType forgetGateError;
281 
283  OutputDataType prevError;
284 
286  OutputDataType outParameter;
287 
289  size_t rhoSize;
290 
292  size_t bpttSteps;
293 }; // class FastLSTM
294 
295 } // namespace ann
296 } // namespace mlpack
297 
298 // Include implementation.
299 #include "fast_lstm_impl.hpp"
300 
301 #endif
OutputDataType & Gradient()
Modify the gradient.
Definition: fast_lstm.hpp:154
OutputDataType & Delta()
Modify the delta.
Definition: fast_lstm.hpp:149
void Forward(InputType &&input, OutputType &&output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputDataType::elem_type ElemType
Definition: fast_lstm.hpp:66
OutputDataType const & Parameters() const
Get the parameters.
Definition: fast_lstm.hpp:137
OutputDataType const & Gradient() const
Get the gradient.
Definition: fast_lstm.hpp:152
size_t & Rho()
Modify the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:134
FastLSTM()
Create the Fast LSTM object.
OutputDataType const & Delta() const
Get the delta.
Definition: fast_lstm.hpp:147
OutputDataType const & OutputParameter() const
Get the output parameter.
Definition: fast_lstm.hpp:142
InputDataType::elem_type InputElemType
Definition: fast_lstm.hpp:65
void Backward(const InputType &&input, ErrorType &&gy, GradientType &&g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
OutputDataType & OutputParameter()
Modify the output parameter.
Definition: fast_lstm.hpp:144
void ResetCell(const size_t size)
size_t Rho() const
Get the maximum number of steps to backpropagate through time (BPTT).
Definition: fast_lstm.hpp:132
void serialize(Archive &ar, const unsigned int)
Serialize the layer.
OutputDataType & Parameters()
Modify the parameters.
Definition: fast_lstm.hpp:139
An implementation of a faster version of the LSTM network layer.
Definition: fast_lstm.hpp:61