double_pole_cart.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
14 #define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace rl {
20 
28 {
29  public:
35  class State
36  {
37  public:
41  State() : data(dimension)
42  { /* Nothing to do here. */ }
43 
49  State(const arma::colvec& data) : data(data)
50  { /* Nothing to do here */ }
51 
53  arma::colvec Data() const { return data; }
55  arma::colvec& Data() { return data; }
56 
58  double Position() const { return data[0]; }
60  double& Position() { return data[0]; }
61 
63  double Velocity() const { return data[1]; }
65  double& Velocity() { return data[1]; }
66 
68  double Angle(const size_t i) const { return data[2 * i]; }
70  double& Angle(const size_t i) { return data[2 * i]; }
71 
73  double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
75  double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
76 
78  const arma::colvec& Encode() const { return data; }
79 
81  static constexpr size_t dimension = 6;
82 
83  private:
85  arma::colvec data;
86  };
87 
91  enum Action
92  {
95 
96  // Track the size of the action space.
98  };
99 
118  DoublePoleCart(const double m1 = 0.1,
119  const double m2 = 0.01,
120  const double l1 = 0.5,
121  const double l2 = 0.05,
122  const double gravity = 9.8,
123  const double massCart = 1.0,
124  const double forceMag = 10.0,
125  const double tau = 0.02,
126  const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
127  const double xThreshold = 2.4,
128  const double doneReward = 0.0,
129  const size_t maxSteps = 0) :
130  m1(m1),
131  m2(m2),
132  l1(l1),
133  l2(l2),
134  gravity(gravity),
135  massCart(massCart),
136  forceMag(forceMag),
137  tau(tau),
138  thetaThresholdRadians(thetaThresholdRadians),
139  xThreshold(xThreshold),
140  doneReward(doneReward),
141  maxSteps(maxSteps),
142  stepsPerformed(0)
143  { /* Nothing to do here */ }
144 
154  double Sample(const State& state,
155  const Action& action,
156  State& nextState)
157  {
158  // Update the number of steps performed.
159  stepsPerformed++;
160 
161  arma::vec dydx(6, arma::fill::zeros);
162  dydx[0] = state.Velocity();
163  dydx[2] = state.AngularVelocity(1);
164  dydx[4] = state.AngularVelocity(2);
165  Dsdt(state, action, dydx);
166  RK4(state, action, dydx, nextState);
167 
168  // Check if the episode has terminated.
169  bool done = IsTerminal(nextState);
170 
171  // Do not reward agent if it failed.
172  if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
173  return doneReward;
174  else if (done)
175  return 0;
176 
181  return 1.0;
182  }
183 
192  void Dsdt(const State& state,
193  const Action& action,
194  arma::vec& dydx)
195  {
196  double totalForce = action ? forceMag : -forceMag;
197  double totalMass = massCart;
198  double omega1 = state.AngularVelocity(1);
199  double omega2 = state.AngularVelocity(2);
200  double sinTheta1 = std::sin(state.Angle(1));
201  double sinTheta2 = std::sin(state.Angle(2));
202  double cosTheta1 = std::cos(state.Angle(1));
203  double cosTheta2 = std::cos(state.Angle(2));
204 
205  // Calculate total effective force.
206  totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
207  std::sin(2 * state.Angle(1));
208  totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
209  std::sin(2 * state.Angle(2));
210 
211  // Calculate total effective mass.
212  totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
213  totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
214 
215  // Calculate acceleration.
216  double xAcc = totalForce / totalMass;
217  dydx[1] = xAcc;
218 
219  // Calculate angular acceleration.
220  dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
221  dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
222  }
223 
233  void RK4(const State& state,
234  const Action& action,
235  arma::vec& dydx,
236  State& nextState)
237  {
238  const double hh = tau * 0.5;
239  const double h6 = tau / 6;
240  arma::vec yt(6);
241  arma::vec dyt(6);
242  arma::vec dym(6);
243 
244  yt = state.Data() + (hh * dydx);
245  Dsdt(State(yt), action, dyt);
246  dyt[0] = yt[1];
247  dyt[2] = yt[3];
248  dyt[4] = yt[5];
249  yt = state.Data() + (hh * dyt);
250 
251  Dsdt(State(yt), action, dym);
252  dym[0] = yt[1];
253  dym[2] = yt[3];
254  dym[4] = yt[5];
255  yt = state.Data() + (tau * dym);
256  dym += dyt;
257 
258  Dsdt(State(yt), action, dyt);
259  dyt[0] = yt[1];
260  dyt[2] = yt[3];
261  dyt[4] = yt[5];
262  nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
263  }
264 
273  double Sample(const State& state, const Action& action)
274  {
275  State nextState;
276  return Sample(state, action, nextState);
277  }
278 
285  {
286  stepsPerformed = 0;
287  return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
288  }
289 
296  bool IsTerminal(const State& state) const
297  {
298  if (maxSteps != 0 && stepsPerformed >= maxSteps)
299  {
300  Log::Info << "Episode terminated due to the maximum number of steps"
301  "being taken.";
302  return true;
303  }
304  if (std::abs(state.Position()) > xThreshold)
305  {
306  Log::Info << "Episode terminated due to cart crossing threshold";
307  return true;
308  }
309  if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
310  std::abs(state.Angle(2)) > thetaThresholdRadians)
311  {
312  Log::Info << "Episode terminated due to pole falling";
313  return true;
314  }
315  return false;
316  }
317 
319  size_t StepsPerformed() const { return stepsPerformed; }
320 
322  size_t MaxSteps() const { return maxSteps; }
324  size_t& MaxSteps() { return maxSteps; }
325 
326  private:
328  double m1;
329 
331  double m2;
332 
334  double l1;
335 
337  double l2;
338 
340  double gravity;
341 
343  double massCart;
344 
346  double forceMag;
347 
349  double tau;
350 
352  double thetaThresholdRadians;
353 
355  double xThreshold;
356 
358  double doneReward;
359 
361  size_t maxSteps;
362 
364  size_t stepsPerformed;
365 };
366 
367 } // namespace rl
368 } // namespace mlpack
369 
370 #endif
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
State(const arma::colvec &data)
Construct a state instance from given data.
size_t MaxSteps() const
Get the maximum number of steps allowed.
double & Velocity()
Modify the velocity of the cart.
.hpp
Definition: add_to_po.hpp:21
Implementation of Double Pole Cart Balancing task.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method...
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Position() const
Get the position of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
size_t StepsPerformed() const
Get the number of steps performed.
size_t & MaxSteps()
Set the maximum number of steps allowed.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
Definition: log.hpp:84
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on given ordinary diffe...
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.
arma::colvec & Data()
Modify the internal representation of the state.
const arma::colvec & Encode() const
Encode the state to a vector.
Implementation of the state of Double Pole Cart.
State()
Construct a state instance.
DoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
double Velocity() const
Get the velocity of the cart.
static constexpr size_t dimension
Dimension of the encoded state.
double & Position()
Modify the position of the cart.
Action
Implementation of action of Double Pole Cart.
arma::colvec Data() const
Get the internal representation of the state.