/**
 * @file methods/ann/layer/recurrent_attention.hpp
 *
 * Definition of the RecurrentAttention class, which implements the Recurrent
 * Model for Visual Attention.
 *
 * mlpack is free software; you may redistribute it and/or modify it under the
 * terms of the 3-clause BSD license.  You should have received a copy of the
 * 3-clause BSD license along with mlpack.  If not, see
 * http://www.opensource.org/licenses/BSD-3-Clause for more information.
 */
#ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_HPP
#define MLPACK_METHODS_ANN_LAYER_RECURRENT_ATTENTION_HPP

#include <mlpack/prereqs.hpp>

#include "layer_types.hpp"
#include "add_merge.hpp"
#include "sequential.hpp"

namespace mlpack {
namespace ann {

// TODO: refactor recurrent layer
/**
 * This class implements the Recurrent Model for Visual Attention, using a
 * variety of possible layer implementations.
 *
 * For more information, see the following paper:
 *
 * @code
 * @article{MnihHGK14,
 *   title   = {Recurrent Models of Visual Attention},
 *   author  = {Volodymyr Mnih, Nicolas Heess, Alex Graves, Koray Kavukcuoglu},
 *   journal = {CoRR},
 *   volume  = {abs/1406.6247},
 *   year    = {2014},
 *   url     = {https://arxiv.org/abs/1406.6247}
 * }
 * @endcode
 *
 * @tparam InputType Type of the input data (arma::colvec, arma::mat,
 *     arma::sp_mat or arma::cube).
 * @tparam OutputType Type of the output data (arma::colvec, arma::mat,
 *     arma::sp_mat or arma::cube).
 */
template <
    typename InputType = arma::mat,
    typename OutputType = arma::mat
>
class RecurrentAttention : public MultiLayer<InputType, OutputType>
{
 public:
  /**
   * Default constructor: this will not give a usable RecurrentAttention
   * object, so be sure to set all the parameters before use.
   */
  RecurrentAttention();

  /**
   * Create the RecurrentAttention object using the specified modules.
   *
   * @param outSize The module output size.
   * @param rnn The recurrent module.
   * @param action The action module.
   * @param rho Maximum number of steps to backpropagate through time (BPTT).
   */
  template<typename RNNModuleType, typename ActionModuleType>
  RecurrentAttention(const size_t outSize,
                     const RNNModuleType& rnn,
                     const ActionModuleType& action,
                     const size_t rho);

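  // Minimal construction sketch (illustrative only; how the submodules are
  // assembled is an assumption, not part of this file):
  //
  //   Sequential<> rnn;     // recurrent module, combining glimpse and state
  //   Sequential<> action;  // action module, producing the next location
  //
  //   // 10-dimensional output, unrolled for at most rho = 5 time steps.
  //   RecurrentAttention<> attention(10, rnn, action, 5);
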
  /**
   * Ordinary feed forward pass of a neural network, evaluating the function
   * f(x) by propagating the activity forward through f.
   *
   * @param input Input data used for evaluating the specified function.
   * @param output Resulting output activation.
   */
  void Forward(const InputType& input, OutputType& output);

  /**
   * Ordinary feed backward pass of a neural network, calculating the function
   * f(x) by propagating x backwards through f, using the results from the
   * feed forward pass.
   *
   * @param * (input) The propagated input activation.
   * @param gy The backpropagated error.
   * @param g The calculated gradient.
   */
  void Backward(const InputType& /* input */,
                const OutputType& gy,
                OutputType& g);

  /**
   * Calculate the gradient using the output delta and the input activation.
   *
   * @param * (input) The input parameter used for calculating the gradient.
   * @param * (error) The calculated error.
   * @param * (gradient) The calculated gradient.
   */
  void Gradient(const InputType& /* input */,
                const OutputType& /* error */,
                OutputType& /* gradient */);

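  // Typical call sequence for one step, shown for orientation (in practice
  // the enclosing network object drives these calls, not user code):
  //
  //   layer.Forward(input, output);
  //   layer.Backward(input, error, delta);
  //   layer.Gradient(input, error, gradient);
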
  //! Get the number of steps to backpropagate through time.
  size_t const& Rho() const { return rho; }

  /**
   * Serialize the layer.
   */
  template<typename Archive>
  void serialize(Archive& ar, const uint32_t /* version */);

 private:
  //! Calculate the gradient of the attention module.
  void IntermediateGradient()
  {
    intermediateGradient.zeros();

    // Gradient of the action module.
    if (backwardStep == (rho - 1))
    {
      actionModule->Gradient(initialInput, actionError,
          actionModule->Gradient());
    }
    else
    {
      actionModule->Gradient(actionModule->OutputParameter(), actionError,
          actionModule->Gradient());
    }

    // Gradient of the recurrent module.
    rnnModule->Gradient(rnnModule->OutputParameter(), recurrentError,
        rnnModule->Gradient());

    attentionGradient += intermediateGradient;
  }

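  // Note (an inference from the members above, not stated in this file): the
  // implementation presumably calls IntermediateGradient() once per unrolled
  // time step, so after rho calls attentionGradient holds the attention
  // module's gradient summed over the whole unrolled sequence.
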
  //! Locally-stored module output size.
  size_t outSize;

  //! Locally-stored recurrent module.
  Layer<InputType, OutputType>* rnnModule;

  //! Locally-stored action module.
  Layer<InputType, OutputType>* actionModule;

  //! Number of steps to backpropagate through time (BPTT).
  size_t rho;

  //! Locally-stored number of forward steps.
  size_t forwardStep;

  //! Locally-stored number of backward steps.
  size_t backwardStep;

  //! Locally-stored weight object.
  OutputType parameters;

  //! Locally-stored model modules.
  std::vector<Layer<InputType, OutputType>*> network;

  //! Locally-stored feedback output parameters.
  std::vector<OutputType> feedbackOutputParameter;

  //! List of all module output parameters for the backward pass (BPTT).
  std::vector<OutputType> moduleOutputParameter;

  //! Locally-stored recurrent error parameter.
  OutputType recurrentError;

  //! Locally-stored action error parameter.
  OutputType actionError;

  //! Locally-stored action delta.
  OutputType actionDelta;

  //! Locally-stored recurrent delta.
  OutputType rnnDelta;

  //! Locally-stored initial action input.
  InputType initialInput;

  //! Locally-stored attention gradient.
  OutputType attentionGradient;

  //! Locally-stored intermediate gradient for the attention module.
  OutputType intermediateGradient;
}; // class RecurrentAttention

} // namespace ann
} // namespace mlpack

// Include implementation.
#include "recurrent_attention_impl.hpp"

#endif