discrete_distribution.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_CORE_DISTRIBUTIONS_DISCRETE_DISTRIBUTION_HPP
15 #define MLPACK_CORE_DISTRIBUTIONS_DISCRETE_DISTRIBUTION_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 #include <mlpack/core/util/log.hpp>
20 
21 namespace mlpack {
22 namespace distribution {
23 
47 {
48  public:
54  probabilities(std::vector<arma::vec>(1)){ /* Nothing to do. */ }
55 
64  DiscreteDistribution(const size_t numObservations) :
65  probabilities(std::vector<arma::vec>(1,
66  arma::ones<arma::vec>(numObservations) / numObservations))
67  { /* Nothing to do. */ }
68 
77  DiscreteDistribution(const arma::Col<size_t>& numObservations)
78  {
79  for (size_t i = 0; i < numObservations.n_elem; i++)
80  {
81  const size_t numObs = size_t(numObservations[i]);
82  if (numObs <= 0)
83  {
84  std::ostringstream oss;
85  oss << "number of observations for dimension " << i << " is 0, but "
86  << "must be greater than 0";
87  throw std::invalid_argument(oss.str());
88  }
89  probabilities.push_back(arma::ones<arma::vec>(numObs) / numObs);
90  }
91  }
92 
99  DiscreteDistribution(const std::vector<arma::vec>& probabilities)
100  {
101  for (size_t i = 0; i < probabilities.size(); i++)
102  {
103  arma::vec temp = probabilities[i];
104  double sum = accu(temp);
105  if (sum > 0)
106  this->probabilities.push_back(temp / sum);
107  else
108  {
109  this->probabilities.push_back(arma::ones<arma::vec>(temp.n_elem)
110  / temp.n_elem);
111  }
112  }
113  }
114 
118  size_t Dimensionality() const { return probabilities.size(); }
119 
128  double Probability(const arma::vec& observation) const
129  {
130  double probability = 1.0;
131  // Ensure the observation has the same dimension with the probabilities.
132  if (observation.n_elem != probabilities.size())
133  {
134  Log::Fatal << "DiscreteDistribution::Probability(): observation has "
135  << "incorrect dimension " << observation.n_elem << " but should have"
136  << " dimension " << probabilities.size() << "!" << std::endl;
137  }
138 
139  for (size_t dimension = 0; dimension < observation.n_elem; dimension++)
140  {
141  // Adding 0.5 helps ensure that we cast the floating point to a size_t
142  // correctly.
143  const size_t obs = size_t(observation(dimension) + 0.5);
144 
145  // Ensure that the observation is within the bounds.
146  if (obs >= probabilities[dimension].n_elem)
147  {
148  Log::Fatal << "DiscreteDistribution::Probability(): received "
149  << "observation " << obs << "; observation must be in [0, "
150  << probabilities[dimension].n_elem << "] for this distribution."
151  << std::endl;
152  }
153  probability *= probabilities[dimension][obs];
154  }
155 
156  return probability;
157  }
158 
167  double LogProbability(const arma::vec& observation) const
168  {
169  // TODO: consider storing log probabilities instead?
170  return log(Probability(observation));
171  }
172 
180  void Probability(const arma::mat& x, arma::vec& probabilities) const
181  {
182  probabilities.set_size(x.n_cols);
183  for (size_t i = 0; i < x.n_cols; i++)
184  probabilities(i) = Probability(x.unsafe_col(i));
185  }
186 
195  void LogProbability(const arma::mat& x, arma::vec& logProbabilities) const
196  {
197  logProbabilities.set_size(x.n_cols);
198  for (size_t i = 0; i < x.n_cols; i++)
199  logProbabilities(i) = log(Probability(x.unsafe_col(i)));
200  }
201 
209  arma::vec Random() const;
210 
218  void Train(const arma::mat& observations);
219 
229  void Train(const arma::mat& observations,
230  const arma::vec& probabilities);
231 
233  arma::vec& Probabilities(const size_t dim = 0) { return probabilities[dim]; }
235  const arma::vec& Probabilities(const size_t dim = 0) const
236  { return probabilities[dim]; }
237 
241  template<typename Archive>
242  void serialize(Archive& ar, const unsigned int /* version */)
243  {
244  ar & BOOST_SERIALIZATION_NVP(probabilities);
245  }
246 
247  private:
250  std::vector<arma::vec> probabilities;
251 };
252 
253 } // namespace distribution
254 } // namespace mlpack
255 
256 #endif
DiscreteDistribution()
Default constructor, which creates a distribution that has no observations.
.hpp
Definition: add_to_po.hpp:21
A discrete distribution where the only observations are discrete observations.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Definition: prereqs.hpp:55
arma::vec Random() const
Return a randomly generated observation (one-dimensional vector; one observation) according to the pr...
double LogProbability(const arma::vec &observation) const
Return the log probability of the given observation.
const arma::vec & Probabilities(const size_t dim=0) const
Return the (read-only) vector of probabilities for the given dimension.
static MLPACK_EXPORT util::PrefixedOutStream Fatal
Prints fatal messages prefixed with [FATAL], then terminates the program.
Definition: log.hpp:90
void serialize(Archive &ar, const unsigned int)
Serialize the distribution.
size_t Dimensionality() const
Get the dimensionality of the distribution.
DiscreteDistribution(const arma::Col< size_t > &numObservations)
Define the multidimensional discrete distribution as having numObservations possible observations...
void Train(const arma::mat &observations)
Estimate the probability distribution directly from the given observations.
Miscellaneous math random-related routines.
double Probability(const arma::vec &observation) const
Return the probability of the given observation.
arma::vec & Probabilities(const size_t dim=0)
Return a modifiable reference to the vector of probabilities for the given dimension.
DiscreteDistribution(const size_t numObservations)
Define the discrete distribution as having numObservations possible observations. ...
DiscreteDistribution(const std::vector< arma::vec > &probabilities)
Define the multidimensional discrete distribution as having the given probabilities for each observat...
void LogProbability(const arma::mat &x, arma::vec &logProbabilities) const
Returns the log probability of each data point (column) in the given matrix.
void Probability(const arma::mat &x, arma::vec &probabilities) const
Calculates the probability mass function of the discrete distribution for each data point (column) in the given matrix...