multiply_merge.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_HPP
14 #define MLPACK_METHODS_ANN_LAYER_MULTIPLY_MERGE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 // #include "../visitor/delete_visitor.hpp"
19 // #include "../visitor/delta_visitor.hpp"
20 // #include "../visitor/output_parameter_visitor.hpp"
21 
22 #include "layer_types.hpp"
23 
24 namespace mlpack {
25 namespace ann {
26 
36 template<
37  typename InputType = arma::mat,
38  typename OutputType = arma::mat
39 >
40 class MultiplyMergeType : public MultiLayer<InputType, OutputType>
41 {
42  public:
49  MultiplyMergeType(const bool model = false, const bool run = true);
50 
52  MultiplyMerge(const MultiplyMerge& layer);
53 
56 
58  MultiplyMerge& operator=(const MultiplyMerge& layer);
59 
62 
65 
67  MultiplyMergeType* Clone() const { return new MultiplyMergeType(*this); }
68 
76  void Forward(const InputType& /* input */, OutputType& output);
77 
87  void Backward(const InputType& /* input */,
88  const OutputType& gy,
89  OutputType& g);
90 
91  /*
92  * Calculate the gradient using the output delta and the input activation.
93  *
94  * @param input The input parameter used for calculating the gradient.
95  * @param error The calculated error.
96  * @param gradient The calculated gradient.
97  */
98  void Gradient(const InputType& input,
99  const OutputType& error,
100  OutputType& gradient);
101 
103  OutputType const& Parameters() const { return weights; }
105  OutputType& Parameters() { return weights; }
106 
108  size_t WeightSize() const { return 0; }
109 
113  template<typename Archive>
114  void serialize(Archive& ar, const uint32_t /* version */);
115 
116  private:
119  bool run;
120 
122  bool ownsLayer;
123 
125  OutputType weights;
126 }; // class MultiplyMergeType
127 
128 // Standard MultiplyMerge layer.
130 
131 } // namespace ann
132 } // namespace mlpack
133 
134 // Include implementation.
135 #include "multiply_merge_impl.hpp"
136 
137 #endif
MultiplyMergeType(const bool model=false, const bool run=true)
Create the MultiplyMerge object using the specified parameters.
Linear algebra utility functions, generally performed on matrices or vectors.
void Forward(const InputType &, OutputType &output)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.
OutputType & Parameters()
Modify the parameters.
size_t WeightSize() const
Get the size of the weights.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the MultiplyMerge module class.
void Gradient(const InputType &input, const OutputType &error, OutputType &gradient)
Calculate the gradient using the output delta and the input activation.
OutputType const & Parameters() const
Get the parameters.
A "multi-layer" is a layer that is a wrapper around other layers.
Definition: multi_layer.hpp:34
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
MultiplyMerge(const MultiplyMerge &layer)
Copy Constructor.
MultiplyMergeType * Clone() const
Clone the MultiplyMergeType object. This handles polymorphism correctly.
void Backward(const InputType &, const OutputType &gy, OutputType &g)
Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f, using the results from the feed forward pass.
MultiplyMerge & operator=(const MultiplyMerge &layer)
Copy assignment operator.
~MultiplyMergeType()
Destructor to release allocated memory.