concat.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_CONCAT_HPP
14 #define MLPACK_METHODS_ANN_LAYER_CONCAT_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 #include "layer.hpp"
19 
20 namespace mlpack {
21 namespace ann {
22 
34 template <typename MatType = arma::mat>
35 class ConcatType : public MultiLayer<MatType>
36 {
37  public:
42  ConcatType();
43 
50  ConcatType(const size_t axis);
51 
55  virtual ~ConcatType();
56 
58  ConcatType* Clone() const { return new ConcatType(*this); }
59 
61  ConcatType(const ConcatType& other);
63  ConcatType(ConcatType&& other);
65  ConcatType& operator=(const ConcatType& other);
68 
76  void Forward(const MatType& input, MatType& output);
77 
87  void Backward(const MatType& /* input */,
88  const MatType& gy,
89  MatType& g);
90 
100  void Backward(const MatType& /* input */,
101  const MatType& gy,
102  MatType& g,
103  const size_t index);
104 
112  void Gradient(const MatType& /* input */,
113  const MatType& error,
114  MatType& /* gradient */);
115 
125  void Gradient(const MatType& input,
126  const MatType& error,
127  MatType& gradient,
128  const size_t index);
129 
131  size_t Axis() const { return axis; }
132 
133  // We don't need to overload WeightSize(); MultiLayer already computes this
134  // correctly. (It is the sum of weights of all child layers.)
135 
137  {
138  // The input is sent to every layer.
139  for (size_t i = 0; i < this->network.size(); ++i)
140  {
141  this->network[i]->InputDimensions() = this->inputDimensions;
142  this->network[i]->ComputeOutputDimensions();
143  }
144 
145  const size_t numOutputDimensions = (this->network.size() == 0) ?
146  this->inputDimensions.size() :
147  this->network[0]->OutputDimensions().size();
148 
149  // If the user did not specify an axis, we will use the last one.
150  // Otherwise, we must sanity check to ensure that the axis we are
151  // concatenating along is valid.
152  if (!useAxis)
153  {
154  axis = this->inputDimensions.size() - 1;
155  }
156  else if (axis >= numOutputDimensions)
157  {
158  std::ostringstream oss;
159  oss << "Concat::ComputeOutputDimensions(): cannot concatenate outputs "
160  << "along axis " << axis << " when input only has "
161  << this->inputDimensions.size() << " axes!";
162  throw std::invalid_argument(oss.str());
163  }
164 
165  // Now, we concatenate the output along a specific axis.
166  this->outputDimensions = std::vector<size_t>(numOutputDimensions, 0);
167  for (size_t i = 0; i < this->outputDimensions.size(); ++i)
168  {
169  if (i == axis)
170  {
171  // Accumulate output size along this axis for each layer output.
172  for (size_t n = 0; n < this->network.size(); ++n)
173  this->outputDimensions[i] += this->network[n]->OutputDimensions()[i];
174  }
175  else
176  {
177  // Ensure that the output size is the same along this axis.
178  const size_t axisDim = this->network[0]->OutputDimensions()[i];
179  for (size_t n = 1; n < this->network.size(); ++n)
180  {
181  const size_t axisDim2 = this->network[n]->OutputDimensions()[i];
182  if (axisDim != axisDim2)
183  {
184  std::ostringstream oss;
185  oss << "Concat::ComputeOutputDimensions(): cannot concatenate "
186  << "outputs along axis " << axis << "; held layer " << n
187  << " has output size " << axisDim2 << " along axis " << i
188  << ", but the first held layer has output size " << axisDim
189  << "! All layers must have identical output size in any "
190  << "axis other than the concatenated axis.";
191  throw std::invalid_argument(oss.str());
192  }
193  }
194 
195  this->outputDimensions[i] = axisDim;
196  }
197  }
198 
199  // Recompute total input and output sizes. Note that we pass the input to
200  // each layer held in the network, so the "total" input size (which is used
201  // by the backwards pass to compute how much memory to use for holding
202  // deltas) should be the number of layers multiplied by the input size for
203  // each layer.
204  this->totalInputSize = 1;
205  this->totalOutputSize = 1;
206  for (size_t i = 0; i < this->inputDimensions.size(); ++i)
207  this->totalInputSize *= this->inputDimensions[i];
208  this->totalInputSize *= this->network.size();
209  for (size_t i = 0; i < this->outputDimensions.size(); ++i)
210  this->totalOutputSize *= this->outputDimensions[i];
211  }
212 
216  template<typename Archive>
217  void serialize(Archive& ar, const uint32_t /* version */);
218 
219  private:
221  size_t axis;
222 
224  bool useAxis;
225 }; // class ConcatType.
226 
227 // Standard Concat layer.
229 
230 } // namespace ann
231 } // namespace mlpack
232 
233 // Include implementation.
234 #include "concat_impl.hpp"
235 
236 #endif
std::vector< size_t > inputDimensions
Logical input dimensions of each point.
Definition: layer.hpp:302
void ComputeOutputDimensions()
Compute the output dimensions of the MultiLayer using InputDimensions().
Definition: concat.hpp:136
Linear algebra utility functions, generally performed on matrices or vectors.
void Backward(const MatType &, const MatType &gy, MatType &g)
Ordinary feed backward pass of a neural network, using 3rd-order tensors as input, calculating the function f(x) by propagating x backwards through f.
The core includes that mlpack expects; standard C++ includes and Armadillo.
const std::vector< size_t > & OutputDimensions()
Get the output dimensions.
Definition: layer.hpp:231
virtual ~ConcatType()
Destroy the layers held by the model.
ConcatType< arma::mat > Concat
Definition: concat.hpp:228
std::vector< size_t > outputDimensions
Logical output dimensions of each point.
Definition: layer.hpp:310
void Forward(const MatType &input, MatType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
A "multi-layer" is a layer that is a wrapper around other layers.
Definition: multi_layer.hpp:34
ConcatType()
Create the Concat object.
ConcatType * Clone() const
Clone the ConcatType object. This handles polymorphism correctly.
Definition: concat.hpp:58
size_t Axis() const
Get the axis of concatenation.
Definition: concat.hpp:131
std::vector< Layer< MatType > * > network
The internally-held network.
Implementation of the Concat class.
Definition: concat.hpp:35
ConcatType & operator=(const ConcatType &other)
Copy the given ConcatType layer.
void Gradient(const MatType &, const MatType &error, MatType &)
Calculate the gradient using the output delta and the input activation.