mean_imputation.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_DATA_IMPUTE_STRATEGIES_MEAN_IMPUTATION_HPP
13 #define MLPACK_CORE_DATA_IMPUTE_STRATEGIES_MEAN_IMPUTATION_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace data {
23 template <typename T>
25 {
26  public:
37  void Impute(arma::Mat<T>& input,
38  const T& mappedValue,
39  const size_t dimension,
40  const bool columnMajor = true)
41  {
42  double sum = 0;
43  size_t elems = 0; // excluding nan or missing target
44 
45  using PairType = std::pair<size_t, size_t>;
46  // dimensions and indexes are saved as pairs inside this vector.
47  std::vector<PairType> targets;
48 
49 
50  // calculate number of elements and sum of them excluding mapped value or
51  // nan. while doing that, remember where mappedValue or NaN exists.
52  if (columnMajor)
53  {
54  for (size_t i = 0; i < input.n_cols; ++i)
55  {
56  if (input(dimension, i) == mappedValue ||
57  std::isnan(input(dimension, i)))
58  {
59  targets.emplace_back(dimension, i);
60  }
61  else
62  {
63  elems++;
64  sum += input(dimension, i);
65  }
66  }
67  }
68  else
69  {
70  for (size_t i = 0; i < input.n_rows; ++i)
71  {
72  if (input(i, dimension) == mappedValue ||
73  std::isnan(input(i, dimension)))
74  {
75  targets.emplace_back(i, dimension);
76  }
77  else
78  {
79  elems++;
80  sum += input(i, dimension);
81  }
82  }
83  }
84 
85  if (elems == 0)
86  Log::Fatal << "it is impossible to calculate mean; no valid elements in "
87  << "the dimension" << std::endl;
88 
89  // calculate mean;
90  const double mean = sum / elems;
91 
92  // Now replace the calculated mean to the missing variables
93  // It only needs to loop through targets vector, not the whole matrix.
94  for (const PairType& target : targets)
95  {
96  input(target.first, target.second) = mean;
97  }
98  }
99 }; // class MeanImputation
100 
101 } // namespace data
102 } // namespace mlpack
103 
104 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
constexpr auto data(Container const &container) noexcept -> decltype(container.data())
Definition: iterator.hpp:79
A simple mean imputation class.
static util::PrefixedOutStream Fatal
Definition: log.hpp:105
void Impute(arma::Mat< T > &input, const T &mappedValue, const size_t dimension, const bool columnMajor=true)
Impute function searches through the input looking for mappedValue and replaces it with the mean of the given dimension.