continuous_double_pole_cart.hpp
Go to the documentation of this file.
1 
14 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
15 #define MLPACK_METHODS_RL_ENVIRONMENT_CONTINUOUS_DOUBLE_POLE_CART_HPP
16 
17 #include <mlpack/prereqs.hpp>
18 
19 namespace mlpack {
20 namespace rl {
21 
29 {
30  public:
36  class State
37  {
38  public:
42  State() : data(dimension)
43  { /* Nothing to do here. */ }
44 
50  State(const arma::colvec& data) : data(data)
51  { /* Nothing to do here */ }
52 
54  arma::colvec Data() const { return data; }
56  arma::colvec& Data() { return data; }
57 
59  double Position() const { return data[0]; }
61  double& Position() { return data[0]; }
62 
64  double Velocity() const { return data[1]; }
66  double& Velocity() { return data[1]; }
67 
69  double Angle(const size_t i) const { return data[2 * i]; }
71  double& Angle(const size_t i) { return data[2 * i]; }
72 
74  double AngularVelocity(const size_t i) const { return data[2 * i + 1]; }
76  double& AngularVelocity(const size_t i) { return data[2 * i + 1]; }
77 
79  const arma::colvec& Encode() const { return data; }
80 
82  static constexpr size_t dimension = 6;
83 
84  private:
86  arma::colvec data;
87  };
88 
92  struct Action
93  {
94  double action[1];
95  // Storing degree of freedom
96  const int size = 1;
97  };
98 
116  ContinuousDoublePoleCart(const double m1 = 0.1,
117  const double m2 = 0.01,
118  const double l1 = 0.5,
119  const double l2 = 0.05,
120  const double gravity = 9.8,
121  const double massCart = 1.0,
122  const double forceMag = 10.0,
123  const double tau = 0.02,
124  const double thetaThresholdRadians = 36 * 2 *
125  3.1416 / 360,
126  const double xThreshold = 2.4,
127  const double doneReward = 0.0,
128  const size_t maxSteps = 0) :
129  m1(m1),
130  m2(m2),
131  l1(l1),
132  l2(l2),
133  gravity(gravity),
134  massCart(massCart),
135  forceMag(forceMag),
136  tau(tau),
137  thetaThresholdRadians(thetaThresholdRadians),
138  xThreshold(xThreshold),
139  doneReward(doneReward),
140  maxSteps(maxSteps),
141  stepsPerformed(0)
142  { /* Nothing to do here */ }
143 
153  double Sample(const State& state,
154  const Action& action,
155  State& nextState)
156  {
157  // Update the number of steps performed.
158  stepsPerformed++;
159 
160  arma::vec dydx(6, arma::fill::zeros);
161  dydx[0] = state.Velocity();
162  dydx[2] = state.AngularVelocity(1);
163  dydx[4] = state.AngularVelocity(2);
164  Dsdt(state, action, dydx);
165  RK4(state, action, dydx, nextState);
166 
167  // Check if the episode has terminated.
168  bool done = IsTerminal(nextState);
169 
170  // Do not reward agent if it failed.
171  if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
172  return doneReward;
173  else if (done)
174  return 0;
175 
180  return 1.0;
181  }
182 
191  void Dsdt(const State& state,
192  const Action& action,
193  arma::vec& dydx)
194  {
195  double totalForce = action.action[0];
196  double totalMass = massCart;
197  double omega1 = state.AngularVelocity(1);
198  double omega2 = state.AngularVelocity(2);
199  double sinTheta1 = std::sin(state.Angle(1));
200  double sinTheta2 = std::sin(state.Angle(2));
201  double cosTheta1 = std::cos(state.Angle(1));
202  double cosTheta2 = std::cos(state.Angle(2));
203 
204  // Calculate total effective force.
205  totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
206  std::sin(2 * state.Angle(1));
207  totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
208  std::sin(2 * state.Angle(2));
209 
210  // Calculate total effective mass.
211  totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
212  totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
213 
214  // Calculate acceleration.
215  double xAcc = totalForce / totalMass;
216  dydx[1] = xAcc;
217 
218  // Calculate angular acceleration.
219  dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
220  dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
221  }
222 
232  void RK4(const State& state,
233  const Action& action,
234  arma::vec& dydx,
235  State& nextState)
236  {
237  const double hh = tau * 0.5;
238  const double h6 = tau / 6;
239  arma::vec yt(6);
240  arma::vec dyt(6);
241  arma::vec dym(6);
242 
243  yt = state.Data() + (hh * dydx);
244  Dsdt(State(yt), action, dyt);
245  dyt[0] = yt[1];
246  dyt[2] = yt[3];
247  dyt[4] = yt[5];
248  yt = state.Data() + (hh * dyt);
249 
250  Dsdt(State(yt), action, dym);
251  dym[0] = yt[1];
252  dym[2] = yt[3];
253  dym[4] = yt[5];
254  yt = state.Data() + (tau * dym);
255  dym += dyt;
256 
257  Dsdt(State(yt), action, dyt);
258  dyt[0] = yt[1];
259  dyt[2] = yt[3];
260  dyt[4] = yt[5];
261  nextState.Data() = state.Data() + h6 * (dydx + dyt + 2 * dym);
262  }
263 
272  double Sample(const State& state, const Action& action)
273  {
274  State nextState;
275  return Sample(state, action, nextState);
276  }
277 
284  {
285  stepsPerformed = 0;
286  return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
287  }
288 
295  bool IsTerminal(const State& state) const
296  {
297  if (maxSteps != 0 && stepsPerformed >= maxSteps)
298  {
299  Log::Info << "Episode terminated due to the maximum number of steps"
300  "being taken.";
301  return true;
302  }
303  if (std::abs(state.Position()) > xThreshold)
304  {
305  Log::Info << "Episode terminated due to cart crossing threshold";
306  return true;
307  }
308  if (std::abs(state.Angle(1)) > thetaThresholdRadians ||
309  std::abs(state.Angle(2)) > thetaThresholdRadians)
310  {
311  Log::Info << "Episode terminated due to pole falling";
312  return true;
313  }
314  return false;
315  }
316 
318  size_t StepsPerformed() const { return stepsPerformed; }
319 
321  size_t MaxSteps() const { return maxSteps; }
323  size_t& MaxSteps() { return maxSteps; }
324 
325  private:
327  double m1;
328 
330  double m2;
331 
333  double l1;
334 
336  double l2;
337 
339  double gravity;
340 
342  double massCart;
343 
345  double forceMag;
346 
348  double tau;
349 
351  double thetaThresholdRadians;
352 
354  double xThreshold;
355 
357  double doneReward;
358 
360  size_t maxSteps;
361 
363  size_t stepsPerformed;
364 };
365 
366 } // namespace rl
367 } // namespace mlpack
368 
369 #endif
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
size_t & MaxSteps()
Set the maximum number of steps allowed.
State(const arma::colvec &data)
Construct a state instance from given data.
constexpr auto size(Container const &container) noexcept -> decltype(container.size())
Definition: iterator.hpp:29
size_t MaxSteps() const
Get the maximum number of steps allowed.
double & Velocity()
Modify the velocity of the cart.
Linear algebra utility functions, generally performed on matrices or vectors.
Implementation of the action of Continuous Double Pole Cart.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
size_t StepsPerformed() const
Get the number of steps performed.
const arma::colvec & Encode() const
Encode the state to a vector.
Implementation of the state of Continuous Double Pole Cart.
static constexpr size_t dimension
Dimension of the encoded state.
double & Position()
Modify the position of the cart.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on the given ordinary differential equations.
double Position() const
Get the position of the cart.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method.
double Velocity() const
Get the velocity of the cart.
arma::colvec Data() const
Get the internal representation of the state.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Continuous Double Pole Cart instance.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
ContinuousDoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
static util::PrefixedOutStream Info
Definition: log.hpp:93
arma::colvec & Data()
Modify the internal representation of the state.
Implementation of Continuous Double Pole Cart Balancing task.
double Sample(const State &state, const Action &action)
Dynamics of Continuous Double Pole Cart.