radial_basis_function.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LAYER_RADIAL_BASIS_FUNCTION_HPP
14 #define MLPACK_METHODS_ANN_LAYER_RADIAL_BASIS_FUNCTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
18 
19 #include "layer.hpp"
20 
21 namespace mlpack {
22 namespace ann {
23 
// Implementation of the Radial Basis Function layer.
//
// NOTE(review): the leading integers on each line appear to be line numbers
// carried over from the generated source listing, not part of the header
// itself; gaps in the numbering are presumably where the original doc
// comments were dropped — confirm against the upstream file.
//
// Template parameters:
//   MatType    - matrix type used for inputs, outputs and the stored members
//                (defaults to arma::mat).
//   Activation - defaults to GaussianFunction; presumably the function applied
//                to the computed distances — TODO confirm in the _impl file.
45 template <
46  typename MatType = arma::mat,
47  typename Activation = GaussianFunction
48 >
49 class RBFType : public Layer<MatType>
50 {
51  public:
  // Create the RBFType object.
53  RBFType();
54 
  // Construct the layer from an output size, a matrix of centres and a beta
  // value (betas defaults to 0).
  // NOTE(review): `centres` is taken by non-const reference; whether it is
  // modified or merely copied is defined in the _impl file — confirm.
63  RBFType(const size_t outSize,
64  MatType& centres,
65  double betas = 0);
66 
  // Clone the RBFType object: returns a pointer to a new copy of this layer
  // made with the copy constructor. The copy is allocated with `new`, so the
  // caller presumably takes ownership — verify against Layer's contract.
68  RBFType* Clone() const { return new RBFType(*this); }
69 
70  // Virtual destructor.
71  virtual ~RBFType() { }
72 
  // Copy the given RBFType layer.
74  RBFType(const RBFType& other);
  // Take ownership of the given RBFType layer (move construction).
76  RBFType(RBFType&& other);
  // Copy the given RBFType layer.
78  RBFType& operator=(const RBFType& other);
  // Take ownership of the given RBFType layer (move assignment).
80  RBFType& operator=(RBFType&& other);
81 
  // Ordinary feed forward pass of the radial basis function.
  // Reads `input` and writes the result to `output`.
88  void Forward(const MatType& input, MatType& output);
89 
  // Ordinary feed backward pass of the radial basis function.
  // The parameter names are commented out in this declaration; by position
  // they are the input, the backpropagated error (gy) and the resulting
  // gradient (g) — TODO confirm exact semantics in the _impl file.
93  void Backward(const MatType& /* input */,
94  const MatType& /* gy */,
95  MatType& /* g */);
96 
  // NOTE(review): the listing's member index also describes a
  // ComputeOutputDimensions() member, but no declaration is visible at this
  // point in the listing — confirm against the original header.
100 
  // Get the size of the weights. Always returns 0 for this layer.
102  size_t WeightSize() const { return 0; }
103 
  // Serialize the layer. The version argument is unnamed and unused in this
  // declaration.
107  template<typename Archive>
108  void serialize(Archive& ar, const uint32_t /* version */);
109 
110  private:
  // Output size of the layer — presumably the number of centres; TODO confirm.
112  size_t outSize;
113 
  // Beta value supplied at construction (constructor default is 0).
115  double betas;
116 
  // Centres of the radial basis functions, as passed to the constructor.
118  MatType centres;
119 
  // Presumably caches the input-to-centre distances computed in Forward() —
  // verify in the _impl file.
121  MatType distances;
122 }; // class RBFType
123 
125 
126 } // namespace ann
127 } // namespace mlpack
128 
129 // Include implementation.
130 #include "radial_basis_function_impl.hpp"
131 
132 #endif
void ComputeOutputDimensions()
Compute the output dimensions of the layer given InputDimensions().
RBFType & operator=(const RBFType &other)
Copy the given RBFType layer.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
RBFType< arma::mat > RBF
RBFType()
Create the RBFType object.
RBFType * Clone() const
Clone the RBFType object. This handles polymorphism correctly.
void Backward(const MatType &, const MatType &, MatType &)
Ordinary feed backward pass of the radial basis function.
Implementation of the Radial Basis Function layer.
A layer is an abstract class implementing common neural network operations, such as convolution...
Definition: layer.hpp:52
size_t WeightSize() const
Get the size of the weights.
void Forward(const MatType &input, MatType &output)
Ordinary feed forward pass of the radial basis function.