mean_normalization.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_DATA_MEAN_NORMALIZATION_HPP
13 #define MLPACK_CORE_DATA_MEAN_NORMALIZATION_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace data {
19 
47 {
48  public:
54  template<typename MatType>
55  void Fit(const MatType& input)
56  {
57  itemMean = arma::mean(input, 1);
58  itemMin = arma::min(input, 1);
59  itemMax = arma::max(input, 1);
60  scale = itemMax - itemMin;
61  // Handling zeros in scale vector.
62  scale.for_each([](arma::vec::elem_type& val) { val =
63  (val == 0) ? 1 : val; });
64  }
65 
72  template<typename MatType>
73  void Transform(const MatType& input, MatType& output)
74  {
75  if (itemMean.is_empty() || scale.is_empty())
76  {
77  throw std::runtime_error("Call Fit() before Transform(), please"
78  " refer to the documentation.");
79  }
80  output.copy_size(input);
81  output = (input.each_col() - itemMean).each_col() / scale;
82  }
83 
90  template<typename MatType>
91  void InverseTransform(const MatType& input, MatType& output)
92  {
93  output.copy_size(input);
94  output = (input.each_col() % scale).each_col() + itemMean;
95  }
96 
98  const arma::vec& ItemMean() const { return itemMean; }
100  const arma::vec& ItemMin() const { return itemMin; }
102  const arma::vec& ItemMax() const { return itemMax; }
104  const arma::vec& Scale() const { return scale; }
105 
106  template<typename Archive>
107  void serialize(Archive& ar, const uint32_t /* version */)
108  {
109  ar(CEREAL_NVP(itemMin));
110  ar(CEREAL_NVP(itemMax));
111  ar(CEREAL_NVP(scale));
112  ar(CEREAL_NVP(itemMean));
113  }
114 
115  private:
116  // Per-feature mean, one entry per row of the data given to Fit().
117  arma::vec itemMean;
118  // Per-feature minimum, one entry per row of the data given to Fit().
119  arma::vec itemMin;
120  // Per-feature maximum, one entry per row of the data given to Fit().
121  arma::vec itemMax;
122  // Per-feature range itemMax - itemMin (zero entries replaced by 1); used as
     // the divisor in Transform() and the multiplier in InverseTransform().
123  arma::vec scale;
124 }; // class MeanNormalization
125 
126 } // namespace data
127 } // namespace mlpack
128 
129 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
const arma::vec & Scale() const
Get the scale column vector (per-feature max minus min, with zero ranges replaced by 1).
const arma::vec & ItemMax() const
Get the column vector of per-feature maxima.
The core includes that mlpack expects; standard C++ includes and Armadillo.
A simple Mean Normalization class.
constexpr T const & max(T const &lhs, T const &rhs)
Definition: algorithm.hpp:79
const arma::vec & ItemMin() const
Get the column vector of per-feature minima.
constexpr auto data(Container const &container) noexcept -> decltype(container.data())
Definition: iterator.hpp:79
void Transform(const MatType &input, MatType &output)
Normalize features: subtract each feature's mean, then divide by its scale.
void Fit(const MatType &input)
Fit to the data: compute each feature's mean, minimum, maximum, and scale.
constexpr T const & min(T const &lhs, T const &rhs)
Definition: algorithm.hpp:69
void serialize(Archive &ar, const uint32_t)
void InverseTransform(const MatType &input, MatType &output)
Invert Transform() to recover the dataset in its original scale.
const arma::vec & ItemMean() const
Get the column vector of per-feature means.