rectifier_function.hpp
Go to the documentation of this file.
1 
23 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_RECTIFIER_FUNCTION_HPP
24 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_RECTIFIER_FUNCTION_HPP
25 
26 #include <mlpack/prereqs.hpp>
27 #include <algorithm>
28 
29 namespace mlpack {
30 namespace ann {
31 
46 {
47  public:
/**
 * Computes the rectifier function f(x) = max(0, x) for a single value.
 *
 * @param x Input value.
 * @return x when x is strictly positive, otherwise 0.0.
 */
static double Fn(const double x)
{
  return (x > 0.0) ? x : 0.0;
}
58 
65  template<typename eT>
66  static void Fn(const arma::Mat<eT>& x, arma::Mat<eT>& y)
67  {
68  y.zeros(x.n_rows, x.n_cols);
69  y = arma::max(y, x);
70  }
71 
78  template<typename eT>
79  static void Fn(const arma::Cube<eT>& x, arma::Cube<eT>& y)
80  {
81  y.zeros(x.n_rows, x.n_cols, x.n_slices);
82  y = arma::max(y, x);
83  }
84 
/**
 * Computes the first derivative of the rectifier function for one value.
 *
 * The argument is named y, matching the output name used by Fn(). Presumably
 * callers pass the activation output; for the rectifier, output > 0 exactly
 * when input > 0, so the result is the same either way (confirm at call
 * sites).
 *
 * @param y Value at which to evaluate the derivative.
 * @return 1.0 when y is strictly positive, otherwise 0.0. At y == 0 the
 *     function is not differentiable; 0.0 is returned there. NaN also yields
 *     0.0.
 */
static double Deriv(const double y)
{
  return (y > 0.0) ? 1.0 : 0.0;
}
95 
102  template<typename InputType, typename OutputType>
103  static void Deriv(const InputType& y, OutputType& x)
104  {
105  x.set_size(arma::size(y));
106 
107  for (size_t i = 0; i < y.n_elem; i++)
108  x(i) = Deriv(y(i));
109  }
110 }; // class RectifierFunction
111 
112 } // namespace ann
113 } // namespace mlpack
114 
115 #endif
static void Fn(const arma::Mat< eT > &x, arma::Mat< eT > &y)
Computes the rectifier function using a dense matrix as input.
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
static void Deriv(const InputType &y, OutputType &x)
Computes the first derivatives of the rectifier function.
static double Deriv(const double y)
Computes the first derivative of the rectifier function.
static double Fn(const double x)
Computes the rectifier function.
static void Fn(const arma::Cube< eT > &x, arma::Cube< eT > &y)
Computes the rectifier function using a 3rd-order tensor as input.
The rectifier function, defined by $f(x) = \max(0, x)$.