poisson_nll_loss.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_HPP
14 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_POISSON_NLL_LOSS_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace ann {
20 
29 template<typename MatType = arma::mat>
31 {
32  public:
48  PoissonNLLLossType(const bool logInput = true,
49  const bool full = false,
50  const typename MatType::elem_type eps = 1e-08,
51  const bool reduction = true);
52 
61  typename MatType::elem_type Forward(const MatType& prediction,
62  const MatType& target);
63 
76  void Backward(const MatType& prediction,
77  const MatType& target,
78  MatType& loss);
79 
82  bool LogInput() const { return logInput; }
85  bool& LogInput() { return logInput; }
86 
89  bool Full() const { return full; }
92  bool& Full() { return full; }
93 
96  typename MatType::elem_type Eps() const { return eps; }
99  typename MatType::elem_type& Eps() { return eps; }
100 
103  bool Reduction() const { return reduction; }
105  bool& Reduction() { return reduction; }
106 
110  template<typename Archive>
111  void serialize(Archive& ar, const uint32_t /* version */);
112 
113  private:
115  template<typename eT>
116  void CheckProbs(const arma::Mat<eT>& probs)
117  {
118  for (size_t i = 0; i < probs.size(); ++i)
119  {
120  if (probs[i] > 1.0 || probs[i] < 0.0)
121  Log::Fatal << "Probabilities cannot be greater than 1 "
122  << "or smaller than 0." << std::endl;
123  }
124  }
125 
127  bool logInput;
128 
130  // approximation term.
131  bool full;
132 
134  typename MatType::elem_type eps;
135 
137  bool reduction;
138 }; // class PoissonNLLLossType
139 
140 // Default typedef for typical `arma::mat` usage.
142 
143 } // namespace ann
144 } // namespace mlpack
145 
146 // Include implementation.
147 #include "poisson_nll_loss_impl.hpp"
148 
149 #endif
Implementation of the Poisson negative log likelihood loss.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects: standard C++ includes, Armadillo, cereal, and a few basic mlpack utilities.
MatType::elem_type Forward(const MatType &prediction, const MatType &target)
Computes the Poisson negative log likelihood Loss.
MatType::elem_type & Eps()
Modify the value of eps.
bool & Full()
Modify the value of full.
MatType::elem_type Eps() const
Get the value of eps.
bool Reduction() const
Get the reduction type, represented as a boolean (false: 'mean' reduction, true: 'sum' reduction).
PoissonNLLLossType(const bool logInput=true, const bool full=false, const typename MatType::elem_type eps=1e-08, const bool reduction=true)
Create the PoissonNLLLossType object.
void Backward(const MatType &prediction, const MatType &target, MatType &loss)
Ordinary feed backward pass of a neural network.
bool & Reduction()
Modify the type of reduction used.
void serialize(Archive &ar, const uint32_t)
Serialize the layer.
PoissonNLLLossType< arma::mat > PoissonNLLLoss
bool LogInput() const
Get the value of logInput.
static util::PrefixedOutStream Fatal
Definition: log.hpp:105
bool Full() const
Get the value of full.
bool & LogInput()
Modify the value of logInput.