mlpack  git-master
sarah.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_OPTIMIZERS_SARAH_SARAH_HPP
13 #define MLPACK_CORE_OPTIMIZERS_SARAH_SARAH_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 #include "sarah_update.hpp"
18 #include "sarah_plus_update.hpp"
19 
20 namespace mlpack {
21 namespace optimization {
22 
/**
 * StochAstic Recursive gRadient algoritHm (SARAH) optimizer.
 *
 * This class only declares the optimizer's configuration and entry point; the
 * constructor and Optimize() are defined out of line (see the implementation
 * header included at the bottom of this file).
 *
 * @tparam UpdatePolicyType Update policy used by the optimizer; defaults to
 *     SARAHUpdate.
 */
65 template<typename UpdatePolicyType = SARAHUpdate>
66 class SARAHType
67 {
68  public:
  /**
   * Construct the SARAH optimizer with the given parameters.
   *
   * @param stepSize Step size.
   * @param batchSize Batch size.
   * @param maxIterations Maximum number of iterations (0 indicates no limit).
   * @param innerIterations Number of inner iterations (0 indicates the
   *     default, n / b).
   * @param tolerance Tolerance for termination.
   * @param shuffle Whether or not the individual functions are shuffled.
   * @param updatePolicy Update policy instance; default-constructed if not
   *     given.
   */
90  SARAHType(const double stepSize = 0.01,
91  const size_t batchSize = 32,
92  const size_t maxIterations = 1000,
93  const size_t innerIterations = 0,
94  const double tolerance = 1e-5,
95  const bool shuffle = true,
96  const UpdatePolicyType& updatePolicy = UpdatePolicyType());
97 
  /**
   * Optimize the given function using SARAH.
   *
   * @tparam DecomposableFunctionType Type of the function to be optimized.
   * @param function Function to optimize.
   * @param iterate Coordinates matrix, passed by non-const reference.
   *     NOTE(review): presumably the starting point on entry and the final
   *     point on return -- confirm against sarah_impl.hpp.
   * @return NOTE(review): presumably the objective value at the final
   *     iterate -- confirm against sarah_impl.hpp.
   */
108  template<typename DecomposableFunctionType>
109  double Optimize(DecomposableFunctionType& function, arma::mat& iterate);
110 
  //! Get the step size.
112  double StepSize() const { return stepSize; }
  //! Modify the step size.
114  double& StepSize() { return stepSize; }
115 
  //! Get the batch size.
117  size_t BatchSize() const { return batchSize; }
  //! Modify the batch size.
119  size_t& BatchSize() { return batchSize; }
120 
  //! Get the maximum number of iterations (0 indicates no limit).
122  size_t MaxIterations() const { return maxIterations; }
  //! Modify the maximum number of iterations (0 indicates no limit).
124  size_t& MaxIterations() { return maxIterations; }
125 
  //! Get the number of inner iterations (0 indicates default n / b).
127  size_t InnerIterations() const { return innerIterations; }
  //! Modify the number of inner iterations (0 indicates default n / b).
129  size_t& InnerIterations() { return innerIterations; }
130 
  //! Get the tolerance for termination.
132  double Tolerance() const { return tolerance; }
  //! Modify the tolerance for termination.
134  double& Tolerance() { return tolerance; }
135 
  //! Get whether or not the individual functions are shuffled.
137  bool Shuffle() const { return shuffle; }
  //! Modify whether or not the individual functions are shuffled.
139  bool& Shuffle() { return shuffle; }
140 
  //! Get the update policy.
142  const UpdatePolicyType& UpdatePolicy() const { return updatePolicy; }
  //! Modify the update policy.
144  UpdatePolicyType& UpdatePolicy() { return updatePolicy; }
145 
146  private:
  //! The step size.
148  double stepSize;
149 
  //! The batch size.
151  size_t batchSize;
152 
  //! The maximum number of iterations (0 indicates no limit).
154  size_t maxIterations;
155 
  //! The number of inner iterations (0 indicates default n / b).
157  size_t innerIterations;
158 
  //! The tolerance for termination.
160  double tolerance;
161 
  //! Whether or not the individual functions are shuffled.
164  bool shuffle;
165 
  //! The update policy used by the optimizer.
167  UpdatePolicyType updatePolicy;
168 };
169 
170 // Convenience typedefs.
171 
176 
181 
182 } // namespace optimization
183 } // namespace mlpack
184 
185 // Include implementation.
186 #include "sarah_impl.hpp"
187 
188 #endif
size_t MaxIterations() const
Get the maximum number of iterations (0 indicates no limit).
Definition: sarah.hpp:122
.hpp
Definition: add_to_po.hpp:21
double & Tolerance()
Modify the tolerance for termination.
Definition: sarah.hpp:134
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t & InnerIterations()
Modify the maximum number of inner iterations (0 indicates default n / b).
Definition: sarah.hpp:129
size_t & MaxIterations()
Modify the maximum number of iterations (0 indicates no limit).
Definition: sarah.hpp:124
bool & Shuffle()
Modify whether or not the individual functions are shuffled.
Definition: sarah.hpp:139
bool Shuffle() const
Get whether or not the individual functions are shuffled.
Definition: sarah.hpp:137
UpdatePolicyType & UpdatePolicy()
Modify the update policy.
Definition: sarah.hpp:144
StochAstic Recursive gRadient algoritHm (SARAH).
Definition: sarah.hpp:66
size_t InnerIterations() const
Get the maximum number of inner iterations (0 indicates default n / b).
Definition: sarah.hpp:127
double Optimize(DecomposableFunctionType &function, arma::mat &iterate)
Optimize the given function using SARAH.
double Tolerance() const
Get the tolerance for termination.
Definition: sarah.hpp:132
SARAHType(const double stepSize=0.01, const size_t batchSize=32, const size_t maxIterations=1000, const size_t innerIterations=0, const double tolerance=1e-5, const bool shuffle=true, const UpdatePolicyType &updatePolicy=UpdatePolicyType())
Construct the SARAH optimizer with the given parameters.
const UpdatePolicyType & UpdatePolicy() const
Get the update policy.
Definition: sarah.hpp:142
size_t & BatchSize()
Modify the batch size.
Definition: sarah.hpp:119
double & StepSize()
Modify the step size.
Definition: sarah.hpp:114
double StepSize() const
Get the step size.
Definition: sarah.hpp:112
size_t BatchSize() const
Get the batch size.
Definition: sarah.hpp:117