mlpack  git-master
sgd.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_CORE_OPTIMIZERS_SGD_SGD_HPP
16 #define MLPACK_CORE_OPTIMIZERS_SGD_SGD_HPP
17 
18 #include <mlpack/prereqs.hpp>
23 
24 namespace mlpack {
25 namespace optimization {
26 
// NOTE(review): the leading integers on every line of this listing (84, 85,
// 86, ...) are line numbers left over from a Doxygen source view; they are
// not part of the C++ source and must be stripped before this compiles.
/**
 * Stochastic gradient descent optimizer.
 *
 * @tparam UpdatePolicyType Policy used to update the iterate at each step
 *     (defaults to VanillaUpdate).
 * @tparam DecayPolicyType Policy used to decay the step size (defaults to
 *     NoDecay).
 */
84 template<typename UpdatePolicyType = VanillaUpdate,
85  typename DecayPolicyType = NoDecay>
86 class SGD
87 {
88  public:
  /**
   * Construct the SGD optimizer with the given parameters. The function to
   * optimize is not passed here; it is given to Optimize().
   *
   * @param stepSize Step size (default 0.01).
   * @param batchSize Batch size (default 32). Presumably the number of
   *     separable functions evaluated per update; confirm in sgd_impl.hpp.
   * @param maxIterations Maximum number of iterations; 0 indicates no limit
   *     (default 100000).
   * @param tolerance Tolerance for termination (default 1e-5).
   * @param shuffle Whether or not the individual functions are shuffled
   *     (default true).
   * @param updatePolicy Update policy instance (default: a default-constructed
   *     UpdatePolicyType). It is stored by value in the updatePolicy member,
   *     presumably as a copy of this argument.
   * @param decayPolicy Step size decay policy instance (default: a
   *     default-constructed DecayPolicyType), stored by value in decayPolicy.
   * @param resetPolicy Whether or not the update policy parameters are reset
   *     before Optimize() is called (default true).
   */
110  SGD(const double stepSize = 0.01,
111  const size_t batchSize = 32,
112  const size_t maxIterations = 100000,
113  const double tolerance = 1e-5,
114  const bool shuffle = true,
115  const UpdatePolicyType& updatePolicy = UpdatePolicyType(),
116  const DecayPolicyType& decayPolicy = DecayPolicyType(),
117  const bool resetPolicy = true);
118 
  /**
   * Optimize the given function using stochastic gradient descent.
   *
   * @tparam DecomposableFunctionType Type of the function to optimize. It is
   *     presumably a function expressible as a sum of separable functions;
   *     the exact interface it must provide is defined by sgd_impl.hpp (TODO
   *     confirm).
   * @param function Function to optimize.
   * @param iterate Starting point. Taken by non-const reference, so it
   *     presumably also receives the final point (confirm in sgd_impl.hpp).
   * @return A double; presumably the objective value at the final point
   *     (confirm in sgd_impl.hpp).
   */
129  template<typename DecomposableFunctionType>
130  double Optimize(DecomposableFunctionType& function,
131  arma::mat& iterate);
132 
  //! Get the step size.
134  double StepSize() const { return stepSize; }
  //! Modify the step size.
136  double& StepSize() { return stepSize; }
137 
  //! Get the batch size.
139  size_t BatchSize() const { return batchSize; }
  //! Modify the batch size.
141  size_t& BatchSize() { return batchSize; }
142 
  //! Get the maximum number of iterations (0 indicates no limit).
144  size_t MaxIterations() const { return maxIterations; }
  //! Modify the maximum number of iterations (0 indicates no limit).
146  size_t& MaxIterations() { return maxIterations; }
147 
  //! Get the tolerance for termination.
149  double Tolerance() const { return tolerance; }
  //! Modify the tolerance for termination.
151  double& Tolerance() { return tolerance; }
152 
  //! Get whether or not the individual functions are shuffled.
154  bool Shuffle() const { return shuffle; }
  //! Modify whether or not the individual functions are shuffled.
156  bool& Shuffle() { return shuffle; }
157 
  //! Get whether or not the update policy parameters are reset before
  //! Optimize() is called.
160  bool ResetPolicy() const { return resetPolicy; }
  //! Modify whether or not the update policy parameters are reset before
  //! Optimize() is called.
163  bool& ResetPolicy() { return resetPolicy; }
164 
  //! Get the update policy.
166  const UpdatePolicyType& UpdatePolicy() const { return updatePolicy; }
  //! Modify the update policy.
168  UpdatePolicyType& UpdatePolicy() { return updatePolicy; }
169 
  //! Get the step size decay policy.
171  const DecayPolicyType& DecayPolicy() const { return decayPolicy; }
  //! Modify the step size decay policy.
173  DecayPolicyType& DecayPolicy() { return decayPolicy; }
174 
175  private:
  // NOTE(review): sgd_impl.hpp (not visible here) presumably refers to these
  // members by name and initializes them in this declaration order; keep
  // both stable when editing.
  //! The step size.
177  double stepSize;
178 
  //! The batch size.
180  size_t batchSize;
181 
  //! The maximum number of iterations allowed (0 indicates no limit).
183  size_t maxIterations;
184 
  //! The tolerance for termination.
186  double tolerance;
187 
  //! Whether or not the individual functions are shuffled.
190  bool shuffle;
191 
  //! The update policy instance.
193  UpdatePolicyType updatePolicy;
194 
  //! The step size decay policy instance.
196  DecayPolicyType decayPolicy;
197 
  //! Whether or not the update policy parameters are reset before
  //! Optimize() is called.
200  bool resetPolicy;
201 };
202 
204 
206 
208 
209 } // namespace optimization
210 } // namespace mlpack
211 
212 // Include implementation.
213 #include "sgd_impl.hpp"
214 
215 #endif
const DecayPolicyType & DecayPolicy() const
Get the step size decay policy.
Definition: sgd.hpp:171
bool Shuffle() const
Get whether or not the individual functions are shuffled.
Definition: sgd.hpp:154
const UpdatePolicyType & UpdatePolicy() const
Get the update policy.
Definition: sgd.hpp:166
.hpp
Definition: add_to_po.hpp:21
size_t BatchSize() const
Get the batch size.
Definition: sgd.hpp:139
bool & Shuffle()
Modify whether or not the individual functions are shuffled.
Definition: sgd.hpp:156
The core includes that mlpack expects: standard C++ includes and Armadillo.
double Tolerance() const
Get the tolerance for termination.
Definition: sgd.hpp:149
size_t MaxIterations() const
Get the maximum number of iterations (0 indicates no limit).
Definition: sgd.hpp:144
double StepSize() const
Get the step size.
Definition: sgd.hpp:134
DecayPolicyType & DecayPolicy()
Modify the step size decay policy.
Definition: sgd.hpp:173
size_t & MaxIterations()
Modify the maximum number of iterations (0 indicates no limit).
Definition: sgd.hpp:146
double & StepSize()
Modify the step size.
Definition: sgd.hpp:136
UpdatePolicyType & UpdatePolicy()
Modify the update policy.
Definition: sgd.hpp:168
bool ResetPolicy() const
Get whether or not the update policy parameters are reset before Optimize() is called.
Definition: sgd.hpp:160
SGD(const double stepSize=0.01, const size_t batchSize=32, const size_t maxIterations=100000, const double tolerance=1e-5, const bool shuffle=true, const UpdatePolicyType &updatePolicy=UpdatePolicyType(), const DecayPolicyType &decayPolicy=DecayPolicyType(), const bool resetPolicy=true)
Construct the SGD optimizer with the given parameters (the function to optimize is passed to Optimize()).
Stochastic Gradient Descent is a technique for minimizing a function which can be expressed as a sum ...
Definition: sgd.hpp:86
size_t & BatchSize()
Modify the batch size.
Definition: sgd.hpp:141
bool & ResetPolicy()
Modify whether or not the update policy parameters are reset before Optimize() is called.
Definition: sgd.hpp:163
double Optimize(DecomposableFunctionType &function, arma::mat &iterate)
Optimize the given function using stochastic gradient descent.
double & Tolerance()
Modify the tolerance for termination.
Definition: sgd.hpp:151