highway.hpp
Go to the documentation of this file.
1 // Temporarily drop.
14 #ifndef MLPACK_METHODS_ANN_LAYER_HIGHWAY_HPP
15 #define MLPACK_METHODS_ANN_LAYER_HIGHWAY_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 
19 #include "layer.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
// Implementation of the Highway layer.
//
// HighwayType derives from MultiLayer<InputType, OutputType>, i.e. it is a
// layer that wraps other layers (held in the inherited `network` vector) and
// adds its own parameters on top of theirs (see WeightSize()).
//
// Template parameters:
//   InputType   matrix type used for inputs (defaults to arma::mat).
//   OutputType  matrix type used for outputs, gradients and the stored
//               parameters (defaults to arma::mat).
48 template <
49  typename InputType = arma::mat,
50  typename OutputType = arma::mat
51 >
52 class HighwayType : public MultiLayer<InputType, OutputType>
53 {
54  public:
    // Create the HighwayType object.
56  HighwayType();
57 
    // Destroy the HighwayType object.
59  virtual ~HighwayType();
60 
    // Clone the HighwayType object. This handles polymorphism correctly.
    // Returns a pointer to a copy allocated with `new` via the copy
    // constructor. NOTE(review): presumably the caller takes ownership of the
    // returned pointer -- confirm against the base Layer contract.
62  HighwayType* Clone() const { return new HighwayType(*this); }
63 
    // Copy-construct from `other`.
    // NOTE(review): copy assignment is documented as "but not weights";
    // presumably the same holds here -- confirm in highway_impl.hpp.
65  HighwayType(const HighwayType& other);
    // Move-construct from `other`.
67  HighwayType(HighwayType&& other);
    // Copy the given HighwayType (but not weights).
69  HighwayType& operator=(const HighwayType& other);
72 
    // Receives a raw pointer to elements of OutputType's element type.
    // NOTE(review): presumably makes this layer's parameters alias the memory
    // at `weightsPtr` (sized per WeightSize()) -- the definition lives in
    // highway_impl.hpp, confirm there.
73  void SetWeights(typename OutputType::elem_type* weightsPtr);
74 
    // Ordinary feed-forward pass of a neural network, evaluating the function
    // f(x) for the given input.
    //
    // input   the data to evaluate.
    // output  receives the result of the pass.
82  void Forward(const InputType& input, OutputType& output);
83 
    // Ordinary feed-backward pass of a neural network, calculating the
    // function f(x) by propagating x backwards.
    //
    // The first parameter (the input) is left unnamed in this declaration.
    // gy  NOTE(review): presumably the error propagated back from the
    //     following layer -- confirm.
    // g   NOTE(review): presumably receives the computed gradient -- confirm.
93  void Backward(const InputType& /* input */,
94  const OutputType& gy,
95  OutputType& g);
96 
    // Calculate the gradient using the output delta and the input activation.
    //
    // input     the input activation.
    // error     the output delta.
    // gradient  receives the calculated gradient.
104  void Gradient(const InputType& input,
105  const OutputType& error,
106  OutputType& gradient);
107 
    // Get the parameters.
109  OutputType const& Parameters() const { return weights; }
    // Modify the parameters.
111  OutputType& Parameters() { return weights; }
112 
    // Get the number of trainable weights.
    //
    // The count is totalInputSize * (totalInputSize + 1) for this layer itself
    // plus the WeightSize() of every layer held in the inherited `network`.
114  size_t WeightSize() const
115  {
    // NOTE(review): n * (n + 1) presumably corresponds to an n x n
    // transformWeight plus an n-element transformBias -- confirm.
116  size_t result = this->totalInputSize * (this->totalInputSize + 1);
    // Add the trainable weights of each wrapped layer.
117  for (size_t i = 0; i < this->network.size(); ++i)
118  result += this->network[i]->WeightSize();
119  return result;
120  }
121 
    // Serialize the layer. The version argument is unnamed in this
    // declaration.
125  template<typename Archive>
126  void serialize(Archive& ar, const uint32_t /* version */);
127 
128  private:
    // Parameter object exposed through Parameters().
130  OutputType weights;
131 
    // NOTE(review): presumably the weights of the transform gate -- confirm.
133  OutputType transformWeight;
134 
    // NOTE(review): presumably the bias of the transform gate -- confirm.
136  OutputType transformBias;
137 
    // NOTE(review): presumably the transform gate's values -- confirm.
139  OutputType transformGate;
140 
    // NOTE(review): presumably the activation of the transform gate -- confirm.
142  OutputType transformGateActivation;
143 
    // NOTE(review): presumably the error of the transform gate -- confirm.
145  OutputType transformGateError;
146 }; // class HighwayType
147 
148 // Standard Highway layer.
150 
151 } // namespace ann
152 } // namespace mlpack
153 
154 // Include implementation.
155 #include "highway_impl.hpp"
156 
157 #endif
OutputType const & Parameters() const
Get the parameters.
Definition: highway.hpp:109
Implementation of the Highway layer.
Definition: highway.hpp:52
Linear algebra utility functions, generally performed on matrices or vectors.
size_t WeightSize() const
Get the number of trainable weights.
Definition: highway.hpp:114
HighwayType< arma::mat, arma::mat > Highway
Definition: highway.hpp:149
The core includes that mlpack expects; standard C++ includes and Armadillo.
OutputType & Parameters()
Modify the parameters.
Definition: highway.hpp:111
HighwayType * Clone() const
Clone the HighwayType object. This handles polymorphism correctly.
Definition: highway.hpp:62
virtual ~HighwayType()
Destroy the HighwayType object.
void Gradient(const InputType &input, const OutputType &error, OutputType &gradient)
Calculate the gradient using the output delta and the input activation.
void SetWeights(typename OutputType::elem_type *weightsPtr)
A "multi-layer" is a layer that is a wrapper around other layers.
Definition: multi_layer.hpp:34
HighwayType & operator=(const HighwayType &other)
Copy the given HighwayType (but not weights).
void Backward(const InputType &, const OutputType &gy, OutputType &g)
Ordinary feed-backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed-forward pass.
HighwayType()
Create the HighwayType object.
std::vector< Layer< InputType > *> network
The internally-held network.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
void Forward(const InputType &input, OutputType &output)
Ordinary feed-forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.