gelu_function.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_FUNCTION_HPP
14 #define MLPACK_METHODS_ANN_ACTIVATION_FUNCTIONS_GELU_FUNCTION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
32 {
33  public:
/**
 * Computes the GELU function for a single value, using the tanh-based
 * approximation
 * f(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
 *
 * @param x Input data.
 * @return f(x).
 */
static double Fn(const double x)
{
  // The cube is formed by plain multiplication; std::pow() is a general
  // transcendental routine and is unnecessary for a small integer exponent.
  const double inner = x + 0.044715 * x * x * x;
  return 0.5 * x * (1 + std::tanh(std::sqrt(2 / M_PI) * inner));
}
45 
52  template<typename InputVecType, typename OutputVecType>
53  static void Fn(const InputVecType& x, OutputVecType& y)
54  {
55  y = 0.5 * x % (1 + arma::tanh(std::sqrt(2 / M_PI) *
56  (x + 0.044715 * arma::pow(x, 3))));
57  }
58 
/**
 * Computes the first derivative of the tanh-approximated GELU function,
 * evaluated at the given value.
 *
 * With u(t) = sqrt(2 / pi) * (t + 0.044715 * t^3) the derivative is
 * 0.5 * (1 + tanh(u)) + 0.5 * t * sech(u)^2 * u'(t).
 *
 * NOTE(review): the argument is named `y`, but the expression is the
 * derivative with respect to the function's input evaluated at that value;
 * confirm that callers pass the value they intend.
 *
 * @param y Point at which the derivative is evaluated.
 * @return The derivative at that point.
 */
static double Deriv(const double y)
{
  // The coefficients are derived from the same constants the forward pass
  // uses rather than from literals truncated to six digits, so the result
  // matches the derivative of Fn() to rounding error.
  const double scale = std::sqrt(2 / M_PI);
  const double cubicCoeff = 0.044715;

  const double ySquared = y * y;

  // Argument shared by tanh() and cosh(); evaluated once.
  const double u = scale * (y + cubicCoeff * ySquared * y);

  // sech(u) = 1 / cosh(u).
  const double sechU = 1 / std::cosh(u);

  // 0.5 * y * u'(y), where u'(y) = scale * (1 + 3 * cubicCoeff * y^2).
  const double halfYTimesDu = 0.5 * scale * y * (1 + 3 * cubicCoeff * ySquared);

  return 0.5 * std::tanh(u) + halfYTimesDu * sechU * sechU + 0.5;
}
72 
79  template<typename InputVecType, typename OutputVecType>
80  static void Deriv(const InputVecType& y, OutputVecType& x)
81  {
82  x = 0.5 * arma::tanh(0.0356774 * arma::pow(y, 3) + 0.797885 * y) +
83  (0.0535161 * arma::pow(y, 3) + 0.398942 * y) %
84  arma::pow(1 / arma::cosh(0.0356774 * arma::pow(y, 3) +
85  0.797885 * y), 2) + 0.5;
86  }
87 }; // class GELUFunction
88 
89 } // namespace ann
90 } // namespace mlpack
91 
92 #endif
static void Deriv(const InputVecType &y, OutputVecType &x)
Computes the first derivatives of the GELU function.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
static void Fn(const InputVecType &x, OutputVecType &y)
Computes the GELU function.
static double Deriv(const double y)
Computes the first derivative of the GELU function.
#define M_PI
Definition: base.hpp:43
static double Fn(const double x)
Computes the GELU function.
The GELU function, defined by $f(x) = 0.5\,x\,\bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^{3})\bigr)\bigr)$.