12 #ifndef MLPACK_METHODS_ANN_LAYER_LSTM_HPP 13 #define MLPACK_METHODS_ANN_LAYER_LSTM_HPP 58 template<
typename MatType = arma::mat>
91 void SetWeights(
typename MatType::elem_type* weightsPtr);
100 void Forward(
const MatType& input, MatType& output);
111 void Backward(
const MatType& input,
const MatType& gy, MatType& g);
121 const MatType& error,
141 return (4 * outSize * inSize + 7 * outSize + 4 * outSize * outSize);
153 this->outputDimensions[0] = outSize;
159 template<
typename Archive>
160 void serialize(Archive& ar,
const uint32_t );
173 MatType output2GateInputWeight;
176 MatType input2GateInputWeight;
179 MatType input2GateInputBias;
182 MatType cell2GateInputWeight;
185 MatType output2GateForgetWeight;
188 MatType input2GateForgetWeight;
191 MatType input2GateForgetBias;
194 MatType cell2GateForgetWeight;
197 MatType output2GateOutputWeight;
200 MatType input2GateOutputWeight;
203 MatType input2GateOutputBias;
206 MatType cell2GateOutputWeight;
223 MatType input2HiddenWeight;
226 MatType input2HiddenBias;
229 MatType output2HiddenWeight;
232 arma::Cube<typename MatType::elem_type> cell;
237 arma::Cube<typename MatType::elem_type> inputGateActivation;
240 arma::Cube<typename MatType::elem_type> forgetGateActivation;
243 arma::Cube<typename MatType::elem_type> outputGateActivation;
246 arma::Cube<typename MatType::elem_type> hiddenLayerActivation;
249 arma::Cube<typename MatType::elem_type> cellActivation;
252 MatType forgetGateError;
255 MatType outputGateError;
258 arma::Cube<typename MatType::elem_type> outParameter;
261 MatType inputCellError;
264 MatType inputGateError;
279 #include "lstm_impl.hpp" std::vector< size_t > inputDimensions
Logical input dimensions of each point.
Linear algebra utility functions, generally performed on matrices or vectors.
LSTMType * Clone() const
Clone the LSTMType object. This handles polymorphism correctly.
void SetWeights(typename MatType::elem_type *weightsPtr)
Reset the layer parameter.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
LSTMType & operator=(const LSTMType &other)
Copy the given LSTMType object.
size_t WeightSize() const
Get the total number of trainable parameters.
void Forward(const MatType &input, MatType &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activ...
std::vector< size_t > outputDimensions
Logical output dimensions of each point.
Implementation of the LSTM module class.
The RecurrentLayer provides a base layer for all layers that have recurrent functionality and store s...
void ComputeOutputDimensions()
Given a properly set InputDimensions(), compute the output dimensions.
void ClearRecurrentState(const size_t bpttSteps, const size_t batchSize)
Reset the recurrent state of the LSTM layer, and allocate enough space to hold bpttSteps of previous ...
void Backward(const MatType &input, const MatType &gy, MatType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backw...
LSTMType< arma::mat > LSTM
const MatType & Parameters() const
Get the parameters.
MatType & Parameters()
Modify the parameters.
LSTMType()
Create the LSTM object.
void Gradient(const MatType &input, const MatType &error, MatType &gradient)
Computing the gradient of the layer with respect to its own input.