linear3d.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_LINEAR3D_HPP
14 #define MLPACK_METHODS_ANN_LAYER_LINEAR3D_HPP
15 
16 #include <mlpack/prereqs.hpp>
18 
19 #include "layer.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
34 template<
35  typename MatType = arma::mat,
36  typename RegularizerType = NoRegularizer
37 >
38 class Linear3DType : public Layer<MatType>
39 {
40  public:
42  Linear3DType();
43 
51  Linear3DType(const size_t outSize,
52  RegularizerType regularizer = RegularizerType());
53 
55  Linear3DType* Clone() const { return new Linear3DType(*this); }
56 
57  // Virtual destructor.
58  virtual ~Linear3DType() { }
59 
61  Linear3DType(const Linear3DType& other);
63  Linear3DType(Linear3DType&& other);
65  Linear3DType& operator=(const Linear3DType& other);
68 
69  /*
70  * Reset the layer parameter.
71  */
72  void SetWeights(typename MatType::elem_type* weightsPtr);
73 
81  void Forward(const MatType& input, MatType& output);
82 
92  void Backward(const MatType& /* input */,
93  const MatType& gy,
94  MatType& g);
95 
103  void Gradient(const MatType& input,
104  const MatType& error,
105  MatType& gradient);
106 
108  MatType const& Parameters() const { return weights; }
110  MatType& Parameters() { return weights; }
111 
113  MatType const& Weight() const { return weight; }
115  MatType& Weight() { return weight; }
116 
118  MatType const& Bias() const { return bias; }
120  MatType& Bias() { return bias; }
121 
123  size_t WeightSize() const { return outSize * (this->inputDimensions[0] + 1); }
124 
127 
131  template<typename Archive>
132  void serialize(Archive& ar, const uint32_t /* version */);
133 
134  private:
136  size_t outSize;
137 
139  MatType weights;
140 
142  MatType weight;
143 
145  MatType bias;
146 
148  RegularizerType regularizer;
149 }; // class Linear
150 
151 // Standard Linear3D layer.
153 
154 } // namespace ann
155 } // namespace mlpack
156 
157 // Include implementation.
158 #include "linear3d_impl.hpp"
159 
160 #endif
size_t WeightSize() const
Return the number of weight elements.
Definition: linear3d.hpp:123
void Gradient(const MatType &input, const MatType &error, MatType &gradient)
Calculate the gradient using the output delta and the input activation.
std::vector< size_t > inputDimensions
Logical input dimensions of each point.
Definition: layer.hpp:302
Linear3DType()
Create the Linear3D object.
void Forward(const MatType &input, MatType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
MatType & Weight()
Modify the weight of the layer.
Definition: linear3d.hpp:115
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear3D layer class.
Definition: linear3d.hpp:38
MatType const & Bias() const
Get the bias of the layer.
Definition: linear3d.hpp:118
MatType & Bias()
Modify the bias weights of the layer.
Definition: linear3d.hpp:120
MatType const & Weight() const
Get the weight of the layer.
Definition: linear3d.hpp:113
MatType & Parameters()
Modify the parameters.
Definition: linear3d.hpp:110
void ComputeOutputDimensions()
Compute the output dimensions for the layer, using InputDimensions().
Linear3DType & operator=(const Linear3DType &other)
Copy the given Linear3DType (but not weights).
Linear3DType< arma::mat, NoRegularizer > Linear3D
Definition: linear3d.hpp:152
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
void SetWeights(typename MatType::elem_type *weightsPtr)
Reset the layer parameter.
MatType const & Parameters() const
Get the parameters.
Definition: linear3d.hpp:108
A layer is an abstract class implementing common neural networks operations, such as convolution...
Definition: layer.hpp:52
void Backward(const MatType &, const MatType &gy, MatType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
Linear3DType * Clone() const
Clone the Linear3DType object. This handles polymorphism correctly.
Definition: linear3d.hpp:55