13 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_FTN_HPP 14 #define MLPACK_METHODS_RL_ENVIRONMENT_FTN_HPP 55 State(
const arma::colvec& data) : data(data)
59 arma::vec&
Data() {
return data; }
62 double Row()
const {
return data[0]; }
64 double&
Row() {
return data[0]; }
67 double Column()
const {
return data[1]; }
94 static const size_t size = 2;
105 const size_t depth = 6) :
109 FruitTree tree(depth);
129 arma::vec currentState {state.
Row(), state.
Column()};
130 arma::vec direction = std::unordered_map<Action::actions, arma::vec>({
131 { Action::actions::left, arma::vec({1, currentState(1)}) },
132 { Action::actions::right, arma::vec({1, currentState(1) + 1}) }
135 arma::vec currentNextState = currentState + direction;
136 nextState.
Row() = currentNextState[0];
137 nextState.
Column() = currentNextState[1];
143 if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
146 return fruitTree.GetReward(state);
161 return Sample(state, action, nextState);
171 return State(arma::zeros<arma::vec>(2));
182 if (maxSteps != 0 && stepsPerformed >= maxSteps)
184 Log::Info <<
"Episode terminated due to the maximum number of steps" 188 else if (state.
Row() == fruitTree.Depth())
190 Log::Info <<
"Episode terminated due to reaching leaf node.";
218 FruitTree(
const size_t depth) : depth(depth)
220 if (
std::find(validDepths.begin(), validDepths.end(), depth) == validDepths.end())
222 throw std::logic_error(
"FruitTree()::FruitTree: Invalid depth value: " + std::to_string(depth) +
" provided. " 223 "Only depth values of: 5, 6, 7 are allowed.");
226 arma::mat branches = arma::zeros((rewardSize, (
size_t) std::pow(2, depth - 1)));
227 tree = arma::join_rows(branches, Fruits());
231 size_t GetIndex(
const State& state)
const 233 return static_cast<size_t>(std::pow(2, state.
Row() - 1) + state.
Column());
237 arma::vec GetReward(
const State& state)
const 239 return tree.col(GetIndex(state));
243 arma::mat Fruits()
const {
return ConvexSetMap.at(depth); }
246 size_t Depth()
const {
return depth; }
256 const std::array<size_t, 3> validDepths {5, 6, 7};
263 size_t stepsPerformed;
double Row() const
Get value of row index of the node.
double & Column()
Modify value of column index of the node.
Implementation of action for Fruit Tree Navigation task.
Implementation of Fruit Tree Navigation state.
arma::vec & Data()
Modify the state representation.
constexpr auto size(Container const &container) noexcept -> decltype(container.size())
Linear algebra utility functions, generally performed on matrices or vectors.
size_t & MaxSteps()
Set the maximum number of steps allowed.
size_t StepsPerformed() const
Get the number of steps performed.
arma::vec Sample(const State &state, const Action &action, State &nextState)
Dynamics of the FTN System.
arma::vec Sample(const State &state, const Action &action)
Dynamics of the FTN System.
static constexpr size_t dimension
Dimension of the encoded state.
static constexpr size_t rewardSize
The reward vector consists of {Protein, Carbs, Fats, Vitamins, Minerals, Water} A total of 6 rewards...
static const std::unordered_map< size_t, arma::mat > ConvexSetMap
Maps the depth value of a Fruit Tree to its corresponding Convex Convergence Set (CCS)...
size_t MaxSteps() const
Get the maximum number of steps allowed.
bool IsTerminal(const State &state) const
This function checks if the FTN has reached the terminal state.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
State InitialSample()
This function performs the null (all-zero) initialization of the state.
FruitTreeNavigation(const size_t maxSteps=500, const size_t depth=6)
Construct a Fruit Tree Navigation instance using the given constants.
static util::PrefixedOutStream Info
double & Row()
Modify value of row index of the node.
State(const arma::colvec &data)
Construct a state instance from given data.
State()
Construct a state instance.
double Column() const
Get value of column index of the node.
Implementation of Fruit Tree Navigation Task.