hard_swish_function.hpp
Go to the documentation of this file.
1 
24 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_HARD_SWISH_FUNCTION_HPP
25 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_HARD_SWISH_FUNCTION_HPP
26 
27 #include <mlpack/prereqs.hpp>
28 
29 namespace mlpack {
30 namespace ann {
48 {
49  public:
/**
 * Evaluates the Hard Swish activation for one scalar.
 *
 * The function is piecewise:
 *   - 0                  when x <= -3,
 *   - x                  when x >= 3,
 *   - x * (x + 3) / 6    otherwise.
 * A NaN input matches neither comparison and therefore propagates through
 * the polynomial branch.
 *
 * @param x Value to activate.
 * @return The Hard Swish of x.
 */
static double Fn(const double x)
{
  return (x <= -3) ? 0.0
                   : ((x >= 3) ? x : x * (x + 3) / 6);
}
65 
/**
 * Applies the scalar Hard Swish function to every element of a container.
 *
 * The output is first resized to the shape of the input via set_size(size(x)).
 * Each element is then written from the matching input element.
 *
 * @param x Input values; must expose n_elem and element access x(i).
 * @param y Output container; resized to match x, then filled with f(x(i)).
 */
template <typename InputVecType, typename OutputVecType>
static void Fn(const InputVecType &x, OutputVecType &y)
{
  y.set_size(size(x));

  for (size_t idx = 0; idx != x.n_elem; ++idx)
  {
    const double activated = Fn(x(idx));
    y(idx) = activated;
  }
}
80 
/**
 * Evaluates the first derivative of the Hard Swish activation for one scalar.
 *
 * The result is piecewise:
 *   - 1                  when y >= 3,
 *   - 0                  when y <= -3,
 *   - (2 * y + 3) / 6    otherwise (the derivative of y * (y + 3) / 6).
 * A NaN input matches neither comparison and propagates through the formula.
 *
 * NOTE(review): the formula treats its argument as the pre-activation input
 * x. If callers pass the activation output f(x) instead, the result is not
 * the true derivative. Confirm against the callers.
 *
 * @param y Value at which the derivative is evaluated.
 * @return The slope of Hard Swish at y.
 */
static double Deriv(const double y)
{
  if (y >= 3)
    return 1.0;  // Identity region.
  if (y <= -3)
    return 0.0;  // Zero region.

  return (2 * y + 3.0) / 6.0;
}
96 
/**
 * Applies the scalar Hard Swish derivative to every element of a container.
 *
 * The output is first resized to the shape of the input via set_size(size(y)).
 * Each element is then written from the matching input element.
 *
 * @param y Values at which to evaluate the derivative; must expose n_elem and
 *          element access y(i).
 * @param x Output container; resized to match y, then filled with Deriv(y(i)).
 */
template <typename InputVecType, typename OutputVecType>
static void Deriv(const InputVecType &y, OutputVecType &x)
{
  x.set_size(size(y));

  for (size_t pos = 0; pos != y.n_elem; ++pos)
  {
    const double slope = Deriv(y(pos));
    x(pos) = slope;
  }
}
111 }; // class HardSwishFunction
112 
113 } // namespace ann
114 } // namespace mlpack
115 
116 #endif
The Hard Swish function, defined by $f(x) = 0$ for $x \le -3$, $f(x) = x$ for $x \ge 3$, and $f(x) = \frac{x(x + 3)}{6}$ otherwise.
constexpr auto size(Container const &container) noexcept -> decltype(container.size())
Definition: iterator.hpp:29
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes and Armadillo.
static void Fn(const InputVecType &x, OutputVecType &y)
Computes the Hard Swish function.
static double Fn(const double x)
Computes the Hard Swish function.
static void Deriv(const InputVecType &y, OutputVecType &x)
Computes the first derivatives of the Hard Swish function.
static double Deriv(const double y)
Computes the first derivative of the Hard Swish function.