simple_tolerance_termination.hpp
Go to the documentation of this file.
1 
12 #ifndef _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
13 #define _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace amf {
19 
30 template <class MatType>
32 {
33  public:
35  SimpleToleranceTermination(const double tolerance = 1e-5,
36  const size_t maxIterations = 10000,
37  const size_t reverseStepTolerance = 3) :
38  tolerance(tolerance),
39  maxIterations(maxIterations),
40  V(nullptr),
41  iteration(1),
42  residueOld(DBL_MAX),
43  residue(DBL_MIN),
44  reverseStepTolerance(reverseStepTolerance),
45  reverseStepCount(0),
46  isCopy(false),
47  c_indexOld(0),
48  c_index(0)
49  { }
50 
56  void Initialize(const MatType& V)
57  {
58  residueOld = DBL_MAX;
59  iteration = 1;
60  residue = DBL_MIN;
61  reverseStepCount = 0;
62  isCopy = false;
63 
64  this->V = &V;
65 
66  c_index = 0;
67  c_indexOld = 0;
68  }
69 
76  bool IsConverged(arma::mat& W, arma::mat& H)
77  {
78  arma::mat WH;
79 
80  WH = W * H;
81 
82  // Compute residue.
83  residueOld = residue;
84  size_t n = V->n_rows;
85  size_t m = V->n_cols;
86  double sum = 0;
87  size_t count = 0;
88  for (size_t i = 0; i < n; ++i)
89  {
90  for (size_t j = 0; j < m; ++j)
91  {
92  double temp = 0;
93  if ((temp = (*V)(i, j)) != 0)
94  {
95  temp = (temp - WH(i, j));
96  temp = temp * temp;
97  sum += temp;
98  count++;
99  }
100  }
101  }
102 
103  residue = sum;
104  if (count > 0)
105  residue /= count;
106  residue = sqrt(residue);
107 
108  // Increment iteration count.
109  iteration++;
110  Log::Info << "Iteration " << iteration << "; residue "
111  << ((residueOld - residue) / residueOld) << ".\n";
112 
113  // If residue tolerance is not satisfied.
114  if ((residueOld - residue) / residueOld < tolerance && iteration > 4)
115  {
116  // Check if this is a first of successive drops.
117  if (reverseStepCount == 0 && isCopy == false)
118  {
119  // Store a copy of W and H matrix.
120  isCopy = true;
121  this->W = W;
122  this->H = H;
123  // Store residue values.
124  c_index = residue;
125  c_indexOld = residueOld;
126  }
127  // Increase successive drop count.
128  reverseStepCount++;
129  }
130  // If tolerance is satisfied.
131  else
132  {
133  // Initialize successive drop count.
134  reverseStepCount = 0;
135  // If residue is droped below minimum scrap stored values.
136  if (residue <= c_indexOld && isCopy == true)
137  {
138  isCopy = false;
139  }
140  }
141 
142  // Check if termination criterion is met.
143  if (reverseStepCount == reverseStepTolerance || iteration > maxIterations)
144  {
145  // If stored values are present replace them with current value as they
146  // represent the minimum residue point.
147  if (isCopy)
148  {
149  W = this->W;
150  H = this->H;
151  residue = c_index;
152  }
153  return true;
154  }
155 
156  return false;
157  }
158 
160  const double& Index() const { return residue; }
161 
163  const size_t& Iteration() const { return iteration; }
164 
166  const size_t& MaxIterations() const { return maxIterations; }
167  size_t& MaxIterations() { return maxIterations; }
168 
170  const double& Tolerance() const { return tolerance; }
171  double& Tolerance() { return tolerance; }
172 
173  private:
175  double tolerance;
177  size_t maxIterations;
178 
180  const MatType* V;
181 
183  size_t iteration;
184 
186  double residueOld;
187  double residue;
188 
190  size_t reverseStepTolerance;
192  size_t reverseStepCount;
193 
196  bool isCopy;
197 
199  arma::mat W;
200  arma::mat H;
201  double c_indexOld;
202  double c_index;
203 }; // class SimpleToleranceTermination
204 
205 } // namespace amf
206 } // namespace mlpack
207 
208 #endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
209 
Linear algebra utility functions, generally performed on matrices or vectors.
SimpleToleranceTermination(const double tolerance=1e-5, const size_t maxIterations=10000, const size_t reverseStepTolerance=3)
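Construct the termination policy; the tolerance, maximum iteration count, and reverse-step tolerance all have default values.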
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
void Initialize(const MatType &V)
Initializes the termination policy before starting the factorization.
const double & Tolerance() const
Access tolerance value.
bool IsConverged(arma::mat &W, arma::mat &H)
Check if termination criterion is met.
This class implements residue tolerance termination policy.
auto count(Range &&rng, T const &value) -> enable_if_t< is_range< Range >::value, decltype(::std::count(::std::begin(::core::forward< Range >(rng)), ::std::end(::core::forward< Range >(rng)), value)) >
Definition: algorithm.hpp:225
const size_t & Iteration() const
Get current iteration count.
static util::PrefixedOutStream Info
Definition: log.hpp:93
const double & Index() const
Get current value of residue.
const size_t & MaxIterations() const
Access upper limit of iteration count.