cart_pole.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
16 #define MLPACK_METHODS_RL_ENVIRONMENT_CART_POLE_HPP
17 
18 #include <mlpack/prereqs.hpp>
19 
20 namespace mlpack {
21 namespace rl {
22 
26 class CartPole
27 {
28  public:
33  class State
34  {
35  public:
39  State() : data(dimension)
40  { /* Nothing to do here. */ }
41 
47  State(const arma::colvec& data) : data(data)
48  { /* Nothing to do here */ }
49 
51  arma::colvec& Data() { return data; }
52 
54  double Position() const { return data[0]; }
56  double& Position() { return data[0]; }
57 
59  double Velocity() const { return data[1]; }
61  double& Velocity() { return data[1]; }
62 
64  double Angle() const { return data[2]; }
66  double& Angle() { return data[2]; }
67 
69  double AngularVelocity() const { return data[3]; }
71  double& AngularVelocity() { return data[3]; }
72 
74  const arma::colvec& Encode() const { return data; }
75 
77  static constexpr size_t dimension = 4;
78 
79  private:
81  arma::colvec data;
82  };
83 
87  class Action
88  {
89  public:
90  enum actions
91  {
94  };
95  // To store the action.
97 
98  // Track the size of the action space.
99  static const size_t size = 2;
100  };
101 
117  CartPole(const size_t maxSteps = 200,
118  const double gravity = 9.8,
119  const double massCart = 1.0,
120  const double massPole = 0.1,
121  const double length = 0.5,
122  const double forceMag = 10.0,
123  const double tau = 0.02,
124  const double thetaThresholdRadians = 12 * 2 * 3.1416 / 360,
125  const double xThreshold = 2.4,
126  const double doneReward = 1.0) :
127  maxSteps(maxSteps),
128  gravity(gravity),
129  massCart(massCart),
130  massPole(massPole),
131  totalMass(massCart + massPole),
132  length(length),
133  poleMassLength(massPole * length),
134  forceMag(forceMag),
135  tau(tau),
136  thetaThresholdRadians(thetaThresholdRadians),
137  xThreshold(xThreshold),
138  doneReward(doneReward),
139  stepsPerformed(0)
140  { /* Nothing to do here */ }
141 
151  double Sample(const State& state,
152  const Action& action,
153  State& nextState)
154  {
155  // Update the number of steps performed.
156  stepsPerformed++;
157 
158  // Calculate acceleration.
159  double force = action.action ? forceMag : -forceMag;
160  double cosTheta = std::cos(state.Angle());
161  double sinTheta = std::sin(state.Angle());
162  double temp = (force + poleMassLength * state.AngularVelocity() *
163  state.AngularVelocity() * sinTheta) / totalMass;
164  double thetaAcc = (gravity * sinTheta - cosTheta * temp) /
165  (length * (4.0 / 3.0 - massPole * cosTheta * cosTheta / totalMass));
166  double xAcc = temp - poleMassLength * thetaAcc * cosTheta / totalMass;
167 
168  // Update states.
169  nextState.Position() = state.Position() + tau * state.Velocity();
170  nextState.Velocity() = state.Velocity() + tau * xAcc;
171  nextState.Angle() = state.Angle() + tau * state.AngularVelocity();
172  nextState.AngularVelocity() = state.AngularVelocity() + tau * thetaAcc;
173 
174  // Check if the episode has terminated.
175  bool done = IsTerminal(nextState);
176 
177  // Do not reward agent if it failed.
178  if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
179  return doneReward;
180 
185  return 1.0;
186  }
187 
196  double Sample(const State& state, const Action& action)
197  {
198  State nextState;
199  return Sample(state, action, nextState);
200  }
201 
208  {
209  stepsPerformed = 0;
210  return State((arma::randu<arma::colvec>(4) - 0.5) / 10.0);
211  }
212 
219  bool IsTerminal(const State& state) const
220  {
221  if (maxSteps != 0 && stepsPerformed >= maxSteps)
222  {
223  Log::Info << "Episode terminated due to the maximum number of steps"
224  "being taken.";
225  return true;
226  }
227  else if (std::abs(state.Position()) > xThreshold ||
228  std::abs(state.Angle()) > thetaThresholdRadians)
229  {
230  Log::Info << "Episode terminated due to agent failing.";
231  return true;
232  }
233  return false;
234  }
235 
237  size_t StepsPerformed() const { return stepsPerformed; }
238 
240  size_t MaxSteps() const { return maxSteps; }
242  size_t& MaxSteps() { return maxSteps; }
243 
244  private:
246  size_t maxSteps;
247 
249  double gravity;
250 
252  double massCart;
253 
255  double massPole;
256 
258  double totalMass;
259 
261  double length;
262 
264  double poleMassLength;
265 
267  double forceMag;
268 
270  double tau;
271 
273  double thetaThresholdRadians;
274 
276  double xThreshold;
277 
279  double doneReward;
280 
282  size_t stepsPerformed;
283 };
284 
285 } // namespace rl
286 } // namespace mlpack
287 
288 #endif
double Velocity() const
Get the velocity.
Definition: cart_pole.hpp:59
State(const arma::colvec &data)
Construct a state instance from given data.
Definition: cart_pole.hpp:47
double AngularVelocity() const
Get the angular velocity.
Definition: cart_pole.hpp:69
constexpr auto size(Container const &container) noexcept -> decltype(container.size())
Definition: iterator.hpp:29
double & Velocity()
Modify the velocity.
Definition: cart_pole.hpp:61
Linear algebra utility functions, generally performed on matrices or vectors.
double Sample(const State &state, const Action &action)
Dynamics of Cart Pole.
Definition: cart_pole.hpp:196
Implementation of action of Cart Pole.
Definition: cart_pole.hpp:87
State()
Construct a state instance.
Definition: cart_pole.hpp:39
constexpr T && forward(remove_reference_t< T > &t) noexcept
Definition: utility.hpp:27
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Cart Pole instance.
Definition: cart_pole.hpp:151
CartPole(const size_t maxSteps=200, const double gravity=9.8, const double massCart=1.0, const double massPole=0.1, const double length=0.5, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=12 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=1.0)
Construct a Cart Pole instance using the given constants.
Definition: cart_pole.hpp:117
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
Definition: cart_pole.hpp:219
double Position() const
Get the position.
Definition: cart_pole.hpp:54
Implementation of the state of Cart Pole.
Definition: cart_pole.hpp:33
size_t & MaxSteps()
Set the maximum number of steps allowed.
Definition: cart_pole.hpp:242
double Angle() const
Get the angle.
Definition: cart_pole.hpp:64
double & Angle()
Modify the angle.
Definition: cart_pole.hpp:66
double & Position()
Modify the position.
Definition: cart_pole.hpp:56
static util::PrefixedOutStream Info
Definition: log.hpp:93
double & AngularVelocity()
Modify the angular velocity.
Definition: cart_pole.hpp:71
arma::colvec & Data()
Modify the internal representation of the state.
Definition: cart_pole.hpp:51
static constexpr size_t dimension
Dimension of the encoded state.
Definition: cart_pole.hpp:77
const arma::colvec & Encode() const
Encode the state to a column vector.
Definition: cart_pole.hpp:74
Implementation of Cart Pole task.
Definition: cart_pole.hpp:26
size_t StepsPerformed() const
Get the number of steps performed.
Definition: cart_pole.hpp:237
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
Definition: cart_pole.hpp:207
size_t MaxSteps() const
Get the maximum number of steps allowed.
Definition: cart_pole.hpp:240