[mlpack] GSoC 2014 simulated annealing optimizer

Zhihao Lou lzh1984 at gmail.com
Mon Apr 14 11:09:17 EDT 2014


Hi Ryan,

I've made the changes. Please take a look.

Zhihao Lou


On Mon, Apr 7, 2014 at 10:59 AM, Ryan Curtin <gth671b at mail.gatech.edu> wrote:

> On Fri, Apr 04, 2014 at 08:25:19PM -0500, Zhihao Lou wrote:
> > Hi Ryan
> >
> > On Fri, Apr 4, 2014 at 4:56 PM, Ryan Curtin <gth671b at mail.gatech.edu> wrote:
> >
> > >
> > >  * Could we templatize the probability distribution from which the next
> > >    sample is chosen, and also the cooling schedule?  This way we can
> > >    have an interface that looks like this:
> > >
> > >    template<typename FunctionType, typename DistributionType, typename
> > >        CoolingScheduleType>
> > >    class SA;
> > >
> > >    This will allow more flexibility to the user in the exact way they
> > >    want to use simulated annealing.  You could take what you already
> > >    have, which is a uniform distribution and a geometric cooling
> > >    schedule, and split them out into a UniformDistribution and
> > >    GeometricCoolingSchedule class.
> > >
> >
> > I'm certainly willing to do that. The problem is, however, how to
> > design the interface between these separate classes and the main loop
> > of the annealing.
> >
> > For example, the geometric cooling schedule only needs the current
> > temperature to calculate the next step's temperature. The same is true for
> > linear and logarithmic schedules, though I don't think anybody should
> > use these two. But the other major class of cooling schedules is
> > adaptive schedules, which usually require additional information like
> > the current value of the cost function (usually for calculating
> > variance, etc.), and Lam's schedule (see my comments in MoveControl)
> > requires knowing whether the last proposed move was accepted. So
> > it is very hard to anticipate what information the cooling schedule
> > will need. The solution I can think of is to pass the SA object itself
> > to the schedule, but this is ugly.
>
> You are right, passing the SA object itself is not a clean solution.
> When designing abstractions like this it is often difficult to predict
> what parameters the templated class will need, like you have pointed
> out.  In this case, I think the best idea is to produce an abstraction
> that takes the current temperature and current objective function value
> (it is probably possible to determine the necessary information for
> Lam's schedule by tracking what the objective function was the last time
> the function was called).  If it needs to be modified later for some
> type of schedule we did not anticipate, we can change it then.
>
> > The other thing is that the actual move size in
> > GenerateMove() is drawn from a double exponential (Laplace) distribution,
> > calculated from the uniform intermediate unif. (I probably need more
> > comments there.) The double exponential distribution is related to the
> > move control and the 0.44 value.  I'd suggest not changing this.
>
> That's fine, but if we can templatize that into the DistributionType
> parameter, that would be great.
>
> > >  * Can you comment what is going on in GenerateMove() and MoveControl()
> > >    a little better?  I can sort of follow what you are doing, but it
> > >    takes a while to figure it out, and a couple of informative comments
> > >    could make it much easier to read.
> > >
> >
> > Sure. I'm going to work on these right now.
>
> Thanks!
>
> > >  * I'd like to add your name to the list of contributors once I work
> > >    this in.  Do you mind if I do this?
> >
> > That will be great!
>
> Ok; I will do that when we finish the design of the optimizer and commit
> it.
>
> Thanks,
>
> Ryan
>
> --
> Ryan Curtin    | "None of your mailman friends can hear you."
> ryan at ratml.org |   - Alpha
>
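A minimal sketch of the interface proposed above, assuming the schedule is called once per proposed move with the current temperature and the current objective value, as the cooling loop in the attached Optimize() does. The class HypotheticalAcceptanceTrackingSchedule, its smoothing parameter and AcceptanceRate() are illustrative assumptions, not part of the patch or of mlpack. It keeps plain geometric cooling and is not Lam's schedule; it only shows that accept/reject information can be recovered by remembering the previous objective value.

```cpp
#include <cassert>

// Hypothetical schedule (not part of the patch). It uses only the
// (currentTemperature, currentValue) interface and infers whether the last
// move was accepted by comparing currentValue with the value seen on the
// previous call. A rejected move restores the old value exactly, so an
// unchanged value is treated as a rejection; an accepted move with exactly
// zero change is misclassified as rejected, which this sketch tolerates.
class HypotheticalAcceptanceTrackingSchedule
{
 public:
  HypotheticalAcceptanceTrackingSchedule(const double lambda = 0.001,
                                         const double smoothing = 0.01) :
      lambda(lambda), smoothing(smoothing), acceptanceRate(0.0),
      lastValue(0.0), hasLastValue(false)
  {
    assert(lambda > 0.0 && lambda < 1.0);
    assert(smoothing > 0.0 && smoothing <= 1.0);
  }

  double nextTemperature(const double currentTemperature,
                         const double currentValue)
  {
    if (hasLastValue)
    {
      const double accepted = (currentValue != lastValue) ? 1.0 : 0.0;
      // Exponential moving average of the accept/reject indicator.
      acceptanceRate += smoothing * (accepted - acceptanceRate);
    }
    lastValue = currentValue;
    hasLastValue = true;
    // Plain geometric cooling; an adaptive rule could use acceptanceRate here.
    return (1.0 - lambda) * currentTemperature;
  }

  //! Smoothed fraction of recent moves inferred to have been accepted.
  double AcceptanceRate() const { return acceptanceRate; }

 private:
  double lambda;
  double smoothing;
  double acceptanceRate;
  double lastValue;
  bool hasLastValue;
};
```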
-------------- next part --------------
Index: src/mlpack/tests/CMakeLists.txt
===================================================================
--- src/mlpack/tests/CMakeLists.txt	(revision 16374)
+++ src/mlpack/tests/CMakeLists.txt	(working copy)
@@ -35,6 +35,7 @@
   pca_test.cpp
   radical_test.cpp
   range_search_test.cpp
+  sa_test.cpp
   save_restore_utility_test.cpp
   sgd_test.cpp
   sort_policy_test.cpp
Index: src/mlpack/tests/sa_test.cpp
===================================================================
--- src/mlpack/tests/sa_test.cpp	(revision 0)
+++ src/mlpack/tests/sa_test.cpp	(revision 0)
@@ -0,0 +1,46 @@
+/*
+ * @file sa_test.cpp
+ * @author Zhihao Lou
+ *
+ * Test file for SA (simulated annealing).
+ */
+#include <mlpack/core.hpp>
+#include <mlpack/core/optimizers/sa/sa.hpp>
+#include <mlpack/core/optimizers/sa/exponential_schedule.hpp>
+#include <mlpack/core/optimizers/sa/laplace_distribution.hpp>
+#include <mlpack/core/optimizers/lbfgs/test_functions.hpp>
+
+#include <mlpack/core/metrics/ip_metric.hpp>
+#include <mlpack/core/metrics/lmetric.hpp>
+#include <mlpack/core/metrics/mahalanobis_distance.hpp>
+
+#include <boost/test/unit_test.hpp>
+#include "old_boost_test_definitions.hpp"
+
+using namespace std;
+using namespace arma;
+using namespace mlpack;
+using namespace mlpack::optimization;
+using namespace mlpack::optimization::test;
+using namespace mlpack::metric;
+
+BOOST_AUTO_TEST_SUITE(SATest);
+
+BOOST_AUTO_TEST_CASE(GeneralizedRosenbrockTest)
+{
+  size_t dim = 50;
+  GeneralizedRosenbrockFunction f(dim);
+
+  LaplaceDistribution moveDist;
+  ExponentialSchedule schedule(1e-5);
+  SA<GeneralizedRosenbrockFunction, LaplaceDistribution, ExponentialSchedule> 
+      sa(f, moveDist, schedule, 1000., 1000, 100, 1e-9, 3, 20, 0.3, 0.3, 10000000);
+  arma::mat coordinates = f.GetInitialPoint();
+  double result = sa.Optimize(coordinates);
+
+  BOOST_REQUIRE_SMALL(result, 1e-6);
+  for (size_t j = 0; j < dim; ++j)
+      BOOST_REQUIRE_CLOSE(coordinates[j], (double) 1.0, 1e-2);
+}
+
+BOOST_AUTO_TEST_SUITE_END();
Index: src/mlpack/core/optimizers/sa/laplace_distribution.cpp
===================================================================
--- src/mlpack/core/optimizers/sa/laplace_distribution.cpp	(revision 0)
+++ src/mlpack/core/optimizers/sa/laplace_distribution.cpp	(revision 0)
@@ -0,0 +1,22 @@
+/*
+ * @file laplace_distribution.cpp
+ * @author Zhihao Lou
+ *
+ * Implementation of Laplace distribution
+ */
+
+#include <mlpack/core.hpp>
+#include "laplace_distribution.hpp"
+using namespace mlpack;
+using namespace mlpack::optimization;
+double LaplaceDistribution::operator () (const double param)
+{
+  // uniform [-1, 1]
+  double unif = 2.0 * math::Random() - 1.0;
+  // Laplace Distribution with mean 0
+  // x = - param * sign(unif) * log(1 - |unif|)
+  if (unif < 0) // C++ has no built-in sign function, so branch on the sign.
+      return (param * std::log(1 + unif));
+  else
+      return (-1.0 * param * std::log(1 - unif));
+}
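The inverse-transform sampling described in the comments above can be checked in isolation. The following is a stand-alone sketch that uses <random> instead of mlpack's math::Random(); the function name SampleLaplace is hypothetical. It also redraws the endpoint values u = -1 and u = 1, for which the logarithm would diverge.

```cpp
#include <cassert>
#include <cmath>
#include <random>

// Hypothetical stand-alone helper mirroring LaplaceDistribution::operator():
// draw u uniformly from [-1, 1) and return -theta * sign(u) * log(1 - |u|),
// which is a zero-mean Laplace sample with scale theta.
double SampleLaplace(const double theta, std::mt19937& rng)
{
  assert(theta > 0.0);
  std::uniform_real_distribution<double> uniform(-1.0, 1.0);
  double u = uniform(rng);
  // |u| == 1 would give log(0); redraw in that (rare) case.
  while (std::fabs(u) >= 1.0)
    u = uniform(rng);
  if (u < 0.0)
    return theta * std::log(1.0 + u);   // non-positive sample
  else
    return -theta * std::log(1.0 - u);  // non-negative sample
}
```

For a zero-mean Laplace distribution with scale theta, E[X] = 0 and E[|X|] = theta, so the sample mean and the mean absolute value of many draws give a quick sanity check.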
Index: src/mlpack/core/optimizers/sa/laplace_distribution.hpp
===================================================================
--- src/mlpack/core/optimizers/sa/laplace_distribution.hpp	(revision 0)
+++ src/mlpack/core/optimizers/sa/laplace_distribution.hpp	(revision 0)
@@ -0,0 +1,34 @@
+/*
+ * @file laplace_distribution.hpp
+ * @author Zhihao Lou
+ *
+ * Laplace (double exponential) distribution used in SA
+ */
+
+#ifndef __MLPACK_CORE_OPTIMIZERS_SA_LAPLACE_DISTRIBUTION_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SA_LAPLACE_DISTRIBUTION_HPP
+
+namespace mlpack {
+namespace optimization {
+
+/* 
+ * The Laplace distribution centered at 0 has pdf
+ * \f[
+ * f(x|\theta) = \frac{1}{2\theta}\exp\left(-\frac{|x|}{\theta}\right)
+ * \f]
+ * given scale parameter \f$\theta\f$.
+ */
+class LaplaceDistribution
+{
+ public:
+  //! Nothing to do for the constructor
+  LaplaceDistribution(){}
+  //! Return a random value from the Laplace distribution with scale param.
+  double operator () (const double param);
+
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
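A short derivation, for reference, that the sampling formula in laplace_distribution.cpp, x = -param * sign(unif) * log(1 - |unif|) with unif uniform on [-1, 1], produces the pdf documented in this header (writing theta for param):

```latex
Let $U \sim \mathrm{Uniform}(-1, 1)$ and $X = -\theta\,\mathrm{sign}(U)\,\ln(1 - |U|)$.
Since $|U| \sim \mathrm{Uniform}(0, 1)$ and is independent of $\mathrm{sign}(U)$,
for $x \ge 0$
\[
  P(|X| > x) = P\!\left(1 - |U| < e^{-x/\theta}\right)
             = P\!\left(|U| > 1 - e^{-x/\theta}\right) = e^{-x/\theta},
\]
so $|X|$ is exponential with mean $\theta$. The sign is $\pm 1$ with
probability $1/2$ each, hence
\[
  f(x \mid \theta) = \frac{1}{2}\cdot\frac{1}{\theta}\,e^{-|x|/\theta}
                   = \frac{1}{2\theta}\exp\!\left(-\frac{|x|}{\theta}\right).
\]
```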
Index: src/mlpack/core/optimizers/sa/exponential_schedule.hpp
===================================================================
--- src/mlpack/core/optimizers/sa/exponential_schedule.hpp	(revision 0)
+++ src/mlpack/core/optimizers/sa/exponential_schedule.hpp	(revision 0)
@@ -0,0 +1,52 @@
+/*
+ * @file exponential_schedule.hpp
+ * @author Zhihao Lou
+ *
+ * Exponential (geometric) cooling schedule used in SA
+ */
+
+#ifndef __MLPACK_CORE_OPTIMIZERS_SA_EXPONENTIAL_SCHEDULE_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SA_EXPONENTIAL_SCHEDULE_HPP
+
+namespace mlpack {
+namespace optimization {
+
+/* 
+ * The exponential cooling schedule cools the temperature T at every step
+ * \f[
+ * T_{n+1}=(1-\lambda)T_{n}
+ * \f]
+ * where \f$ 0<\lambda<1 \f$ is the cooling speed. The smaller \f$ \lambda \f$
+ * is, the slower the cooling and, typically, the better the final result. Some
+ * literature uses \f$ \alpha=1-\lambda \f$ instead. In practice, \f$ \alpha \f$
+ * is very close to 1 and awkward to input (e.g. alpha=0.999999 vs
+ * lambda=1e-6).
+ */
+class ExponentialSchedule
+{
+ public:
+  /* 
+   * Construct the ExponentialSchedule with the given parameter
+   *
+   * @param lambda Cooling speed
+   */
+  ExponentialSchedule(const double lambda = 0.001) : lambda(lambda) { }
+
+  //! Return the next temperature; the current function value is unused.
+  double nextTemperature(const double currentTemperature, const double)
+  { return (1 - lambda) * currentTemperature; }
+
+  //! Get the cooling speed lambda
+  double Lambda() const {return lambda;}
+  //! Modify the cooling speed lambda
+  double& Lambda() {return lambda;}
+ private:
+  double lambda;
+
+
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#endif
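The exponential schedule gives T_n = (1 - lambda)^n T_0, so the number of steps needed to cool from T_0 down to a final temperature T_f is n = ln(T_f / T_0) / ln(1 - lambda), roughly ln(T_0 / T_f) / lambda for small lambda. A small sketch for estimating this when choosing lambda; CoolingSteps and the target temperature are hypothetical, since the patch stops on the frozen criterion or maxIterations rather than at a fixed temperature.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical helper: number of geometric cooling steps
// T_{n+1} = (1 - lambda) T_n needed to go from t0 down to at most tFinal,
// i.e. n = ceil(log(tFinal / t0) / log(1 - lambda)).
std::size_t CoolingSteps(const double t0, const double tFinal,
                         const double lambda)
{
  assert(t0 > tFinal && tFinal > 0.0);
  assert(lambda > 0.0 && lambda < 1.0);
  // log1p(-lambda) stays accurate even for tiny lambda such as 1e-5.
  return static_cast<std::size_t>(
      std::ceil(std::log(tFinal / t0) / std::log1p(-lambda)));
}
```

For example, with the values used in sa_test.cpp (lambda = 1e-5, initial temperature 1000), cooling to 1e-3 takes about 1.38 million steps, i.e. roughly 27,600 sweeps of the 50-dimensional Rosenbrock problem, well inside the 10^7 iteration limit.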
Index: src/mlpack/core/optimizers/sa/sa.hpp
===================================================================
--- src/mlpack/core/optimizers/sa/sa.hpp	(revision 0)
+++ src/mlpack/core/optimizers/sa/sa.hpp	(revision 0)
@@ -0,0 +1,173 @@
+/*
+ * @file sa.hpp
+ * @author Zhihao Lou
+ *
+ * Simulated Annealing (SA)
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SA_SA_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SA_SA_HPP
+
+namespace mlpack {
+namespace optimization {
+/* 
+ * Simulated Annealing is a stochastic optimization algorithm which can often
+ * deliver near-optimal results quickly without knowing the gradient of the
+ * function being optimized. Its hill-climbing capability (occasionally
+ * accepting uphill moves) makes it less vulnerable to local minima. This
+ * implementation uses a pluggable cooling schedule and feedback move control.
+ *
+ * The algorithm keeps the temperature at the initial temperature for initMoves
+ * steps to remove the dependency on the initial condition. After that, it
+ * cools at every step until the system is considered frozen or maxIterations
+ * is reached.
+ *
+ * At every step SA perturbs only one parameter. A pass in which SA perturbs
+ * every parameter once is called a sweep. Every moveCtrlSweep sweeps, the
+ * algorithm performs feedback move control to adjust the average move size
+ * according to the responsiveness of each parameter. The parameter gain
+ * controls the proportion of the feedback control.
+ *
+ * The system is considered "frozen" when its score fails to change by more
+ * than tolerance for maxToleranceSweep consecutive sweeps.
+ *
+ * For SA to work, a function must implement the following methods:
+ *   double Evaluate(const arma::mat& coordinates);
+ *   arma::mat& GetInitialPoint();
+ *
+ * In addition, SA needs a move generation distribution with operator():
+ *   double operator () (const double param);
+ * which returns a random value from the distribution given parameter param,
+ * and a cooling schedule with method:
+ *   double nextTemperature(const double currentTemperature, const double currentValue);
+ * which returns the next temperature given current temperature and the value
+ * of the function being optimized.
+ *
+ * @tparam FunctionType objective function type to be minimized.
+ * @tparam MoveDistributionType distribution type for move generation
+ * @tparam CoolingScheduleType type for cooling schedule
+ */
+template<typename FunctionType, typename MoveDistributionType, typename CoolingScheduleType>
+class SA
+{
+ public:
+  /* 
+   * Construct the SA optimizer with the given function and parameters.
+   *
+   * @param function Function to be minimized.
+   * @param moveDistribution Distribution for move generation
+   * @param coolingSchedule Cooling schedule
+   * @param initT Initial temperature.
+   * @param initMoves Iterations without changing temperature.
+   * @param moveCtrlSweep Sweeps per move control.
+   * @param tolerance Tolerance to consider system frozen.
+   * @param maxToleranceSweep Maximum sweeps below tolerance to consider system frozen.
+   * @param maxMoveCoef Maximum move size.
+   * @param initMoveCoef Initial move size.
+   * @param gain Proportional control in feedback move control.
+   * @param maxIterations Maximum number of iterations allowed (0 indicates no limit).
+   */
+  SA(FunctionType& function,
+     MoveDistributionType& moveDistribution,
+     CoolingScheduleType& coolingSchedule,
+     const double initT = 10000.,
+     const size_t initMoves = 1000,
+     const size_t moveCtrlSweep = 100,
+     const double tolerance = 1e-5,
+     const size_t maxToleranceSweep = 3,
+     const double maxMoveCoef = 20,
+     const double initMoveCoef = 0.3,
+     const double gain = 0.3,
+     const size_t maxIterations = 1000000);
+  /* 
+   * Optimize the given function using simulated annealing. The given starting
+   * point will be modified to store the finishing point of the algorithm, and
+   * the final objective value is returned.
+   *
+   * @param iterate Starting point (will be modified).
+   * @return Objective value of the final point.
+   */
+  double Optimize(arma::mat& iterate);
+
+  //! Get the instantiated function to be optimized.
+  const FunctionType& Function() const {return function;}
+  //! Modify the instantiated function.
+  FunctionType& Function() {return function;}
+
+  //! Get the temperature.
+  double Temperature() const {return T;}
+  //! Modify the temperature.
+  double& Temperature() {return T;}
+
+  //! Get the initial moves.
+  size_t InitMoves() const {return initMoves;}
+  //! Modify the initial moves.
+  size_t& InitMoves() {return initMoves;}
+
+  //! Get sweeps per move control.
+  size_t MoveCtrlSweep() const {return moveCtrlSweep;}
+  //! Modify sweeps per move control.
+  size_t& MoveCtrlSweep() {return moveCtrlSweep;}
+
+  //! Get the tolerance.
+  double Tolerance() const {return tolerance;}
+  //! Modify the tolerance.
+  double& Tolerance() {return tolerance;}
+
+  //! Get the maxToleranceSweep.
+  size_t MaxToleranceSweep() const {return maxToleranceSweep;}
+  //! Modify the maxToleranceSweep.
+  size_t& MaxToleranceSweep() {return maxToleranceSweep;}
+
+  //! Get the gain.
+  double Gain() const {return gain;}
+  //! Modify the gain.
+  double& Gain() {return gain;}
+
+  //! Get the maxIterations.
+  size_t MaxIterations() const {return maxIterations;}
+  //! Modify the maxIterations.
+  size_t& MaxIterations() {return maxIterations;}
+
+  //! Get the maximum move size of each parameter.
+  const arma::mat& MaxMove() const {return maxMove;}
+  //! Modify the maximum move size of each parameter.
+  arma::mat& MaxMove() {return maxMove;}
+
+  //! Get the move size of each parameter.
+  const arma::mat& MoveSize() const {return moveSize;}
+  //! Modify the move size of each parameter.
+  arma::mat& MoveSize() {return moveSize;}
+  //! Return a string representation of this object.
+  std::string ToString() const;
+ private:
+  FunctionType& function;
+  MoveDistributionType& moveDistribution;
+  CoolingScheduleType& coolingSchedule;
+  double T;
+  size_t initMoves;
+  size_t moveCtrlSweep;
+  double tolerance;
+  size_t maxToleranceSweep;
+  double gain;
+  size_t maxIterations;
+  arma::mat maxMove;
+  arma::mat moveSize;
+
+
+  // following variables are initialized inside Optimize
+  arma::mat accept;
+  double energy; 
+  size_t idx;
+  size_t nVars;
+  size_t sweepCounter;
+
+  void GenerateMove(arma::mat& iterate);
+  void MoveControl(size_t nMoves);
+};
+
+}; // namespace optimization
+}; // namespace mlpack
+
+#include "sa_impl.hpp"
+
+#endif
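To illustrate how the template parameters documented in sa.hpp plug into the annealing loop, here is a deliberately simplified, self-contained sketch. It assumes std::vector<double> instead of arma::mat, uses hypothetical names (ToySA, ToyLaplace, ToyExponentialSchedule, Sphere), and omits the initial constant-temperature moves, the feedback move control and the frozen test, so it is not the patch's implementation. It keeps the one-coordinate-per-step moves, the Metropolis criterion and the nextTemperature(temperature, value) call.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Hypothetical move distribution: zero-mean Laplace sample with scale param.
struct ToyLaplace
{
  std::mt19937* rng;
  double operator()(const double param)
  {
    std::uniform_real_distribution<double> uniform(-1.0, 1.0);
    double u = uniform(*rng);
    while (std::fabs(u) >= 1.0)
      u = uniform(*rng);
    return (u < 0.0) ? param * std::log(1.0 + u) : -param * std::log(1.0 - u);
  }
};

// Hypothetical geometric schedule with the documented signature.
struct ToyExponentialSchedule
{
  double lambda;
  double nextTemperature(const double t, const double) { return (1.0 - lambda) * t; }
};

// Simplified SA loop: FunctionType needs double Evaluate(const std::vector<double>&).
template<typename FunctionType, typename DistributionType, typename ScheduleType>
double ToySA(FunctionType& f, DistributionType& dist, ScheduleType& schedule,
             std::vector<double>& x, double temperature, const double moveSize,
             const std::size_t iterations, std::mt19937& rng)
{
  assert(!x.empty() && temperature > 0.0);
  std::uniform_real_distribution<double> uniform(0.0, 1.0);
  double energy = f.Evaluate(x);
  for (std::size_t i = 0; i < iterations; ++i)
  {
    const std::size_t idx = i % x.size();  // cycle through coordinates (sweeps)
    const double old = x[idx];
    x[idx] += dist(moveSize);
    const double newEnergy = f.Evaluate(x);
    const double delta = newEnergy - energy;
    // Metropolis: always accept downhill moves, uphill with exp(-delta / T).
    if (delta <= 0.0 || uniform(rng) < std::exp(-delta / temperature))
      energy = newEnergy;
    else
      x[idx] = old;  // reject: restore the previous coordinate
    temperature = schedule.nextTemperature(temperature, energy);
  }
  return energy;
}

// Hypothetical objective: sum of squares, minimized at the origin.
struct Sphere
{
  double Evaluate(const std::vector<double>& x) const
  {
    double s = 0.0;
    for (double v : x)
      s += v * v;
    return s;
  }
};
```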
Index: src/mlpack/core/optimizers/sa/sa_impl.hpp
===================================================================
--- src/mlpack/core/optimizers/sa/sa_impl.hpp	(revision 0)
+++ src/mlpack/core/optimizers/sa/sa_impl.hpp	(revision 0)
@@ -0,0 +1,196 @@
+/*
+ * @file sa_impl.hpp
+ * @author Zhihao Lou
+ *
+ * The implementation of the SA optimizer.
+ */
+#ifndef __MLPACK_CORE_OPTIMIZERS_SA_SA_IMPL_HPP
+#define __MLPACK_CORE_OPTIMIZERS_SA_SA_IMPL_HPP
+
+namespace mlpack {
+namespace optimization {
+
+template<typename FunctionType, typename MoveDistributionType, typename CoolingScheduleType>
+SA<FunctionType, MoveDistributionType, CoolingScheduleType>::
+SA(FunctionType& function,
+   MoveDistributionType& moveDistribution,
+   CoolingScheduleType& coolingSchedule,
+   const double initT,
+   const size_t initMoves,
+   const size_t moveCtrlSweep,
+   const double tolerance,
+   const size_t maxToleranceSweep,
+   const double maxMoveCoef,
+   const double initMoveCoef,
+   const double gain,
+   const size_t maxIterations) :
+    function(function),
+    moveDistribution(moveDistribution),
+    coolingSchedule(coolingSchedule),
+    T(initT),
+    initMoves(initMoves),
+    moveCtrlSweep(moveCtrlSweep),
+    tolerance(tolerance),
+    maxToleranceSweep(maxToleranceSweep),
+    gain(gain),
+    maxIterations(maxIterations)
+{
+  const size_t rows = function.GetInitialPoint().n_rows;
+  const size_t cols = function.GetInitialPoint().n_cols;
+
+  maxMove.set_size(rows, cols);
+  maxMove.fill(maxMoveCoef);
+  moveSize.set_size(rows, cols);
+  moveSize.fill(initMoveCoef);
+  accept.zeros(rows, cols);
+}
+
+//! Optimize the function (minimize).
+template<typename FunctionType, typename MoveDistributionType, typename CoolingScheduleType>
+double SA<FunctionType, MoveDistributionType, CoolingScheduleType>::Optimize(arma::mat &iterate)
+{
+  const size_t rows = function.GetInitialPoint().n_rows;
+  const size_t cols = function.GetInitialPoint().n_cols;
+  size_t i;
+  size_t frozenCount = 0;
+  energy = function.Evaluate(iterate);
+  double old_energy = energy;
+  // The RNG is not reseeded here; call math::RandomSeed() first for reproducible runs.
+
+  nVars = rows * cols;
+  idx = 0;
+  sweepCounter = 0;
+  accept.zeros();
+
+  // Initial moves at constant temperature to remove the dependency on the initial state.
+  for (i = 0; i < initMoves; ++i) 
+  {
+      GenerateMove(iterate);
+  }
+
+  // Iterating and cooling
+  for (i = 0; maxIterations == 0 || i < maxIterations; ++i)
+  {
+    old_energy = energy;
+    GenerateMove(iterate);
+    T = coolingSchedule.nextTemperature(T, energy);
+    if (std::abs(energy - old_energy) < tolerance)
+      ++ frozenCount;
+    else
+      frozenCount = 0;
+    if (frozenCount >= maxToleranceSweep * nVars)
+    {
+      Log::Info << "SA: minimized within tolerance " << tolerance
+          << " for " << maxToleranceSweep << " sweeps; terminating "
+          << "optimization." << std::endl;
+      return energy;
+    }
+  }
+  Log::Info << "SA: maximum iterations (" << maxIterations << ") reached; "
+      << "terminating optimization." << std::endl;
+  return energy;
+}
+
+/* 
+ * GenerateMove() proposes a move on element iterate(idx) and determines
+ * whether that move is accepted according to the Metropolis criterion.
+ * It then increments idx so the next call will move the next parameter.
+ * When every element of the state has been moved (a sweep), it resets idx
+ * and increments sweepCounter. When sweepCounter reaches moveCtrlSweep, it
+ * calls MoveControl() and resets sweepCounter.
+ */
+template<typename FunctionType, typename MoveDistributionType, typename CoolingScheduleType>
+void SA<FunctionType, MoveDistributionType, CoolingScheduleType>::GenerateMove(arma::mat& iterate)
+{
+  double prevEnergy = energy;
+  double prevValue = iterate(idx);
+  double move = moveDistribution(moveSize(idx));
+  iterate(idx) += move;
+  energy = function.Evaluate(iterate);
+  // According to the Metropolis criterion, accept the move with probability
+  // min{1, exp(-(E_new - E_old) / T)}.
+  double xi = math::Random();
+  double delta = energy - prevEnergy;
+  double criterion = std::exp(-delta / T);
+  if (delta <= 0. || criterion > xi)
+  {
+    accept(idx) += 1.;
+  } 
+  else // reject the move; restore previous state
+  {
+    iterate(idx) = prevValue;
+    energy = prevEnergy;
+  }
+  ++ idx;
+  if (idx == nVars) // end of 1 sweep; wraps around idx and increments sweepCounter
+  { 
+    idx = 0;
+    ++ sweepCounter;
+  }
+  if (sweepCounter == moveCtrlSweep) // do MoveControl
+  {
+    MoveControl(moveCtrlSweep);
+    sweepCounter = 0;
+  }
+}
+
+
+/*
+ * MoveControl() uses a proportional feedback control to determine the size
+ * parameter to pass to the move generation distribution. The target of such
+ * move control is to make the acceptance ratio, accept/nMoves, be as close to
+ * 0.44 as possible. Generally speaking, the larger the move size is, the larger
+ * the change in function value caused by the move, and the less likely the move
+ * is to be accepted by the Metropolis criterion. Thus, the move size is updated by
+ *
+ * log(moveSize) <- log(moveSize) + gain * (accept/nMoves - target)
+ *
+ * For more theory and the mysterious 0.44 value, see Jimmy K.-C. Lam and
+ * Jean-Marc Delosme. An efficient simulated annealing schedule: derivation.
+ * Technical Report 8816, Yale University, 1988
+ */
+template<typename FunctionType, typename MoveDistributionType, typename CoolingScheduleType>
+void SA<FunctionType, MoveDistributionType, CoolingScheduleType>::MoveControl(size_t nMoves)
+{
+  arma::mat target;
+  target.copy_size(accept);
+  target.fill(0.44);
+  moveSize = arma::log(moveSize);
+  moveSize += gain * (accept / (double) nMoves - target);
+  moveSize = arma::exp(moveSize);
+  // To avoid the element-wise arma::min(), which is only available in
+  // Armadillo after v3.930, a for loop is used here instead.
+  for (size_t i = 0; i < nVars; ++i)
+  {
+    moveSize(i) = moveSize(i) > maxMove(i) ? maxMove(i) : moveSize(i);
+  }
+  accept.zeros();
+}
+
+template<typename FunctionType, typename MoveDistributionType, typename CoolingScheduleType>
+std::string SA<FunctionType, MoveDistributionType, CoolingScheduleType>::ToString() const
+{
+  std::ostringstream convert;
+  convert << "SA [" << this << "]" << std::endl;
+  convert << "  Function:" << std::endl;
+  convert << util::Indent(function.ToString(), 2);
+  convert << "  Move Distribution:" << std::endl;
+  convert << util::Indent(moveDistribution.ToString(), 2);
+  convert << "  Cooling Schedule:" << std::endl;
+  convert << util::Indent(coolingSchedule.ToString(), 2);
+  convert << "  Temperature: " << T << std::endl;
+  convert << "  Initial moves: " << initMoves << std::endl;
+  convert << "  Sweeps per move control: " << moveCtrlSweep << std::endl;
+  convert << "  Tolerance: " << tolerance << std::endl;
+  convert << "  Maximum sweeps below tolerance: " << maxToleranceSweep
+      << std::endl;
+  convert << "  Move control gain: " << gain << std::endl;
+  convert << "  Maximum iterations: " << maxIterations << std::endl;
+  return convert.str();
+}
+
+}; // namespace optimization
+}; // namespace mlpack
+
+
+#endif
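A scalar sketch of the MoveControl() update above, to show the size of a single correction. Here acceptRatio stands for accept(i)/nMoves for one parameter, and UpdateMoveSize is a hypothetical helper rather than a function in the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Hypothetical scalar version of the MoveControl() update:
// log(moveSize) += gain * (acceptRatio - 0.44), then clamp to maxMove.
double UpdateMoveSize(const double moveSize, const double acceptRatio,
                      const double gain, const double maxMove)
{
  assert(moveSize > 0.0);
  assert(acceptRatio >= 0.0 && acceptRatio <= 1.0);
  const double target = 0.44;
  const double updated =
      std::exp(std::log(moveSize) + gain * (acceptRatio - target));
  return std::min(updated, maxMove);
}
```

With gain = 0.3 (the constructor default), an acceptance ratio of 0.8 enlarges the move size by a factor of exp(0.3 * 0.36), about 1.11, while a ratio of 0.1 shrinks it by exp(-0.3 * 0.34), about 0.90. Applied every moveCtrlSweep sweeps, this is intended to drive each parameter's acceptance ratio towards 0.44.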
Index: src/mlpack/core/optimizers/sa/CMakeLists.txt
===================================================================
--- src/mlpack/core/optimizers/sa/CMakeLists.txt	(revision 0)
+++ src/mlpack/core/optimizers/sa/CMakeLists.txt	(revision 0)
@@ -0,0 +1,14 @@
+set(SOURCES
+  sa.hpp
+  sa_impl.hpp
+  laplace_distribution.hpp
+  laplace_distribution.cpp
+  exponential_schedule.hpp
+)
+
+set(DIR_SRCS)
+foreach(file ${SOURCES})
+  set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
+endforeach()
+
+set(MLPACK_SRCS ${MLPACK_SRCS} ${DIR_SRCS} PARENT_SCOPE)
Index: src/mlpack/core/optimizers/CMakeLists.txt
===================================================================
--- src/mlpack/core/optimizers/CMakeLists.txt	(revision 16374)
+++ src/mlpack/core/optimizers/CMakeLists.txt	(working copy)
@@ -2,6 +2,7 @@
   aug_lagrangian
   lbfgs
   lrsdp
+  sa
   sgd
 )
 

