13 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP 14 #define MLPACK_METHODS_RL_ENVIRONMENT_DOUBLE_POLE_CART_HPP 49 State(
const arma::colvec& data) : data(data)
53 arma::colvec
Data()
const {
return data; }
55 arma::colvec&
Data() {
return data; }
68 double Angle(
const size_t i)
const {
return data[2 * i]; }
70 double&
Angle(
const size_t i) {
return data[2 * i]; }
78 const arma::colvec&
Encode()
const {
return data; }
119 const double m2 = 0.01,
120 const double l1 = 0.5,
121 const double l2 = 0.05,
122 const double gravity = 9.8,
123 const double massCart = 1.0,
124 const double forceMag = 10.0,
125 const double tau = 0.02,
126 const double thetaThresholdRadians = 36 * 2 * 3.1416 / 360,
127 const double xThreshold = 2.4,
128 const double doneReward = 0.0,
129 const size_t maxSteps = 0) :
138 thetaThresholdRadians(thetaThresholdRadians),
139 xThreshold(xThreshold),
140 doneReward(doneReward),
161 arma::vec dydx(6, arma::fill::zeros);
165 Dsdt(state, action, dydx);
166 RK4(state, action, dydx, nextState);
172 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
196 double totalForce = action ? forceMag : -forceMag;
197 double totalMass = massCart;
200 double sinTheta1 = std::sin(state.
Angle(1));
201 double sinTheta2 = std::sin(state.
Angle(2));
202 double cosTheta1 = std::cos(state.
Angle(1));
203 double cosTheta2 = std::cos(state.
Angle(2));
206 totalForce += m1 * l1 * omega1 * omega1 * sinTheta1 + 0.375 * m1 * gravity *
207 std::sin(2 * state.
Angle(1));
208 totalForce += m2 * l2 * omega2 * omega2 * sinTheta1 + 0.375 * m2 * gravity *
209 std::sin(2 * state.
Angle(2));
212 totalMass += m1 * (0.25 + 0.75 * sinTheta1 * sinTheta1);
213 totalMass += m2 * (0.25 + 0.75 * sinTheta2 * sinTheta2);
216 double xAcc = totalForce / totalMass;
220 dydx[3] = -0.75 * (xAcc * cosTheta1 + gravity * sinTheta1) / l1;
221 dydx[5] = -0.75 * (xAcc * cosTheta2 + gravity * sinTheta2) / l2;
238 const double hh = tau * 0.5;
239 const double h6 = tau / 6;
244 yt = state.
Data() + (hh * dydx);
249 yt = state.
Data() + (hh * dyt);
255 yt = state.
Data() + (tau * dym);
262 nextState.
Data() = state.
Data() + h6 * (dydx + dyt + 2 * dym);
276 return Sample(state, action, nextState);
287 return State((arma::randu<arma::vec>(6) - 0.5) / 10.0);
298 if (maxSteps != 0 && stepsPerformed >= maxSteps)
300 Log::Info <<
"Episode terminated due to the maximum number of steps" 304 if (std::abs(state.
Position()) > xThreshold)
306 Log::Info <<
"Episode terminated due to cart crossing threshold";
309 if (std::abs(state.
Angle(1)) > thetaThresholdRadians ||
310 std::abs(state.
Angle(2)) > thetaThresholdRadians)
312 Log::Info <<
"Episode terminated due to pole falling";
352 double thetaThresholdRadians;
364 size_t stepsPerformed;
double Sample(const State &state, const Action &action)
Dynamics of Double Pole Cart.
State(const arma::colvec &data)
Construct a state instance from given data.
size_t MaxSteps() const
Get the maximum number of steps allowed.
double & Velocity()
Modify the velocity of the cart.
Implementation of Double Pole Cart Balancing task.
double & Angle(const size_t i)
Modify the angle of the $i^{th}$ pole.
void Dsdt(const State &state, const Action &action, arma::vec &dydx)
These are the ordinary differential equations required to estimate the next state through the RK4 method...
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Position() const
Get the position of the cart.
double Angle(const size_t i) const
Get the angle of the $i^{th}$ pole.
size_t StepsPerformed() const
Get the number of steps performed.
size_t & MaxSteps()
Set the maximum number of steps allowed.
State InitialSample()
Initial state representation is randomly generated within [-0.05, 0.05].
static MLPACK_EXPORT util::PrefixedOutStream Info
Prints informational messages if --verbose is specified, prefixed with [INFO ].
void RK4(const State &state, const Action &action, arma::vec &dydx, State &nextState)
This function calls the RK4 iterative method to estimate the next state based on given ordinary diffe...
double & AngularVelocity(const size_t i)
Modify the angular velocity of the $i^{th}$ pole.
double AngularVelocity(const size_t i) const
Get the angular velocity of the $i^{th}$ pole.
bool IsTerminal(const State &state) const
This function checks if the cart has reached the terminal state.
double Sample(const State &state, const Action &action, State &nextState)
Dynamics of Double Pole Cart instance.
arma::colvec & Data()
Modify the internal representation of the state.
const arma::colvec & Encode() const
Encode the state to a vector.
Implementation of the state of Double Pole Cart.
State()
Construct a state instance.
DoublePoleCart(const double m1=0.1, const double m2=0.01, const double l1=0.5, const double l2=0.05, const double gravity=9.8, const double massCart=1.0, const double forceMag=10.0, const double tau=0.02, const double thetaThresholdRadians=36 *2 *3.1416/360, const double xThreshold=2.4, const double doneReward=0.0, const size_t maxSteps=0)
Construct a Double Pole Cart instance using the given constants.
double Velocity() const
Get the velocity of the cart.
static constexpr size_t dimension
Dimension of the encoded state.
double & Position()
Modify the position of the cart.
Action
Implementation of action of Double Pole Cart.
arma::colvec Data() const
Get the internal representation of the state.