recurrent_layer.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_ANN_LAYER_RECURRENT_LAYER_HPP
13 #define MLPACK_METHODS_ANN_LAYER_RECURRENT_LAYER_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include "layer.hpp"
17 
18 namespace mlpack {
19 namespace ann {
20 
/**
 * Base layer for all layers that have recurrent functionality.
 *
 * It tracks two step indices: the current step (CurrentStep()) and the step
 * used in the previous call to Forward() or Backward() (PreviousStep()).
 * Derived classes must implement ClearRecurrentState().
 *
 * @tparam MatType Matrix type used by the layer (defaults to arma::mat).
 *
 * NOTE(review): this listing is a Doxygen rendering; the numeric prefix on each
 * line is a source line number from that page, not part of the C++ code.
 */
35 template<typename MatType = arma::mat>
36 class RecurrentLayer : public Layer<MatType>
37 {
38  public:
  // NOTE(review): the default constructor `RecurrentLayer()` ("Create the
  // RecurrentLayer.") is declared in the real header but is not shown in this
  // listing.
43 
44  // Virtual destructor is required for classes using inheritance.
45  virtual ~RecurrentLayer() { }
46 
  //! Copy the given RecurrentLayer.
48  RecurrentLayer(const RecurrentLayer& other);
  // NOTE(review): the copy-assignment operator
  // `RecurrentLayer& operator=(const RecurrentLayer& other)` ("Copy the given
  // RecurrentLayer.") is declared in the real header but not shown here.
55 
  /**
   * ClearRecurrentState() is called before any forward pass of a recurrent
   * network. It is pure virtual, so every concrete recurrent layer must
   * implement it.
   *
   * @param bpttSteps Presumably the number of backpropagation-through-time
   *     steps for the upcoming pass — TODO confirm against implementations.
   * @param batchSize Presumably the number of columns (points) in each
   *     upcoming batch — TODO confirm.
   */
64  virtual void ClearRecurrentState(
65  const size_t bpttSteps,
66  const size_t batchSize) = 0;
67 
  //! Get the current step index to use in a forward or backward pass.
69  size_t CurrentStep() const { return currentStep; }
  //! Modify the current step index to use in a forward or backward pass.
73  size_t& CurrentStep() { return currentStep; }
74 
  //! Get the previous step index, representing the value of CurrentStep() in
  //! the previous call to Forward() or Backward(). The value size_t(-1) is
  //! used as a "no previous step" sentinel (see HasPreviousStep()).
77  size_t PreviousStep() const { return previousStep; }
  //! Modify the previous step index (same meaning as PreviousStep()).
82  size_t& PreviousStep() { return previousStep; }
83 
  //! Returns true when previousStep is not the sentinel size_t(-1). Per the
  //! layer documentation, this is the case once Forward() or Backward() has
  //! been called since ClearRecurrentState().
87  bool HasPreviousStep() const { return previousStep != size_t(-1); }
88 
  //! Serialize the recurrent layer. The definition is presumably in
  //! recurrent_layer_impl.hpp, which this header includes at the end.
90  template<typename Archive>
91  void serialize(Archive& ar, const uint32_t /* version */);
92 
93  private:
  //! Step index to use in the current forward or backward pass.
  // NOTE(review): no in-class initializer; presumably set by the constructor
  // and/or ClearRecurrentState() implementations — confirm.
96  size_t currentStep;
  //! Value of currentStep during the previous Forward()/Backward() call;
  //! size_t(-1) means there was no previous step.
99  size_t previousStep;
100 };
101 
102 } // namespace ann
103 } // namespace mlpack
104 
105 #include "recurrent_layer_impl.hpp"
106 
107 #endif
size_t CurrentStep() const
Get the current step index to use in a forward or backward pass.
virtual void ClearRecurrentState(const size_t bpttSteps, const size_t batchSize)=0
ClearRecurrentState() is called before any forward pass of a recurrent network.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
size_t PreviousStep() const
Get the previous step index, representing the value of CurrentStep() in the previous call to Forward() or Backward().
void serialize(Archive &ar, const uint32_t)
Serialize the recurrent layer.
The RecurrentLayer provides a base layer for all layers that have recurrent functionality and store s...
size_t & CurrentStep()
Modify the current step index to use in a forward or backward pass.
RecurrentLayer & operator=(const RecurrentLayer &other)
Copy the given RecurrentLayer.
size_t & PreviousStep()
Modify the previous step index, representing the value of CurrentStep() in the previous call to Forward() or Backward().
A layer is an abstract class implementing common neural network operations, such as convolution...
Definition: layer.hpp:52
RecurrentLayer()
Create the RecurrentLayer.
bool HasPreviousStep() const
If Forward() or Backward() has been called since ClearRecurrentState(), this will return true; otherwise it returns false.