greedy_policy.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_POLICY_GREEDY_POLICY_HPP
14 #define MLPACK_METHODS_RL_POLICY_GREEDY_POLICY_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace rl {
20 
30 template <typename EnvironmentType>
32 {
33  public:
35  using ActionType = typename EnvironmentType::Action;
36 
48  GreedyPolicy(const double initialEpsilon,
49  const size_t annealInterval,
50  const double minEpsilon,
51  const double decayRate = 1.0) :
52  epsilon(initialEpsilon),
53  minEpsilon(minEpsilon),
54  delta(((initialEpsilon - minEpsilon) * decayRate) / annealInterval)
55  { /* Nothing to do here. */ }
56 
64  ActionType Sample(const arma::colvec& actionValue, bool deterministic = false)
65  {
66  double exploration = math::Random();
67 
68  // Select the action randomly.
69  if (!deterministic && exploration < epsilon)
70  return static_cast<ActionType>(math::RandInt(ActionType::size));
71 
72  // Select the action greedily.
73  return static_cast<ActionType>(
74  arma::as_scalar(arma::find(actionValue == actionValue.max(), 1)));
75  }
76 
80  void Anneal()
81  {
82  epsilon -= delta;
83  epsilon = std::max(minEpsilon, epsilon);
84  }
85 
89  const double& Epsilon() const { return epsilon; }
90 
91  private:
93  double epsilon;
94 
96  double minEpsilon;
97 
99  double delta;
100 };
101 
102 } // namespace rl
103 } // namespace mlpack
104 
105 #endif
typename EnvironmentType::Action ActionType
Convenient alias for the environment's action type.
.hpp
Definition: add_to_po.hpp:21
Implementation of the epsilon-greedy policy.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Anneal()
Anneal the exploration probability (epsilon) by one step, never letting it drop below the minimum epsilon.
const double & Epsilon() const
double Random()
Generates a uniform random number between 0 and 1.
Definition: random.hpp:78
int RandInt(const int hiExclusive)
Generates a uniform random integer.
Definition: random.hpp:105
GreedyPolicy(const double initialEpsilon, const size_t annealInterval, const double minEpsilon, const double decayRate=1.0)
Constructor for the epsilon-greedy policy class.
ActionType Sample(const arma::colvec &actionValue, bool deterministic=false)
Sample an action based on given action values.