ftn.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_ENVIRONMENT_FTN_HPP
14 #define MLPACK_METHODS_RL_ENVIRONMENT_FTN_HPP
15 
16 #include <mlpack/core.hpp>
17 #include "ftn_param.hpp"
18 
19 namespace mlpack{
20 namespace rl{
21 
36 {
37  public:
42  class State
43  {
44  public:
48  State(): data(dimension) { /* Nothing to do here. */ }
49 
55  State(const arma::colvec& data) : data(data)
56  { /* Nothing to do here. */ }
57 
59  arma::vec& Data() { return data; }
60 
62  double Row() const { return data[0]; }
64  double& Row() { return data[0]; }
65 
67  double Column() const { return data[1]; }
69  double& Column() { return data[1]; }
70 
72  static constexpr size_t dimension = 2;
73 
74  private:
76  arma::vec data;
77  };
78 
/**
 * Implementation of action for Fruit Tree Navigation task.
 */
class Action
{
 public:
  //! Branch taken from the current node. Sample() refers to these values as
  //! Action::actions::left and Action::actions::right.
  enum actions
  {
    left,
    right
  };
  // To store the action. Sample() reads it as `action.action`.
  Action::actions action;

  // Track the size of the action space.
  static const size_t size = 2;
};
96 
104  FruitTreeNavigation(const size_t maxSteps = 500,
105  const size_t depth = 6) :
106  maxSteps(maxSteps),
107  stepsPerformed(0)
108  {
109  FruitTree tree(depth);
110  }
111 
121  arma::vec Sample(const State& state,
122  const Action& action,
123  State& nextState)
124  {
125  // Update the number of steps performed.
126  stepsPerformed++;
127 
128  // Make a vector to estimate nextstate.
129  arma::vec currentState {state.Row(), state.Column()};
130  arma::vec direction = std::unordered_map<Action::actions, arma::vec>({
131  { Action::actions::left, arma::vec({1, currentState(1)}) },
132  { Action::actions::right, arma::vec({1, currentState(1) + 1}) }
133  })[action.action];
134 
135  arma::vec currentNextState = currentState + direction;
136  nextState.Row() = currentNextState[0];
137  nextState.Column() = currentNextState[1];
138 
139  // Check if the episode has terminated.
140  bool done = IsTerminal(nextState);
141 
142  // Do not reward the agent if time ran out.
143  if (done && maxSteps != 0 && stepsPerformed >= maxSteps)
144  return arma::zeros(rewardSize);
145 
146  return fruitTree.GetReward(state);
147  };
148 
158  arma::vec Sample(const State& state, const Action& action)
159  {
160  State nextState;
161  return Sample(state, action, nextState);
162  }
163 
169  {
170  stepsPerformed = 0;
171  return State(arma::zeros<arma::vec>(2));
172  }
173 
180  bool IsTerminal(const State& state) const
181  {
182  if (maxSteps != 0 && stepsPerformed >= maxSteps)
183  {
184  Log::Info << "Episode terminated due to the maximum number of steps"
185  " being taken.";
186  return true;
187  }
188  else if (state.Row() == fruitTree.Depth())
189  {
190  Log::Info << "Episode terminated due to reaching leaf node.";
191  return true;
192  }
193  return false;
194  }
195 
197  size_t StepsPerformed() const { return stepsPerformed; }
198 
200  size_t MaxSteps() const { return maxSteps; }
202  size_t& MaxSteps() { return maxSteps; }
203 
206  static constexpr size_t rewardSize = 6;
207 
208  private:
209 
210  /*
211  * A Fruit Tree is a full binary tree. Each node of this tree represents a state.
212  * Non-leaf nodes yield a null vector R^{6} reward. Leaf nodes have a pre-defined
213  * set of reward vectors such that they lie on a Convex Convergence Set (CCS).
214  */
215  class FruitTree
216  {
217  public:
218  FruitTree(const size_t depth) : depth(depth)
219  {
220  if (std::find(validDepths.begin(), validDepths.end(), depth) == validDepths.end())
221  {
222  throw std::logic_error("FruitTree()::FruitTree: Invalid depth value: " + std::to_string(depth) + " provided. "
223  "Only depth values of: 5, 6, 7 are allowed.");
224  }
225 
226  arma::mat branches = arma::zeros((rewardSize, (size_t) std::pow(2, depth - 1)));
227  tree = arma::join_rows(branches, Fruits());
228  }
229 
230  // Extract array index from {(row, column)} representation.
231  size_t GetIndex(const State& state) const
232  {
233  return static_cast<size_t>(std::pow(2, state.Row() - 1) + state.Column());
234  }
235 
236  // Yield reward from the current node.
237  arma::vec GetReward(const State& state) const
238  {
239  return tree.col(GetIndex(state));
240  }
241 
243  arma::mat Fruits() const { return ConvexSetMap.at(depth); }
244 
246  size_t Depth() const { return depth; }
247 
248  private:
250  size_t depth;
251 
253  arma::mat tree;
254 
256  const std::array<size_t, 3> validDepths {5, 6, 7};
257  };
258 
260  size_t maxSteps;
261 
263  size_t stepsPerformed;
264 
266  FruitTree fruitTree;
267  };
268 };
269 
270 } // namespace rl
271 } // namespace mlpack
272 
273 #endif
double Row() const
Get value of row index of the node.
Definition: ftn.hpp:62
double & Column()
Modify value of column index of the node.
Definition: ftn.hpp:69
Implementation of action for Fruit Tree Navigation task.
Definition: ftn.hpp:82
Implementation of Fruit Tree Navigation state.
Definition: ftn.hpp:42
arma::vec & Data()
Modify the state representation.
Definition: ftn.hpp:59
constexpr auto size(Container const &container) noexcept -> decltype(container.size())
Definition: iterator.hpp:29
Linear algebra utility functions, generally performed on matrices or vectors.
size_t & MaxSteps()
Set the maximum number of steps allowed.
Definition: ftn.hpp:202
size_t StepsPerformed() const
Get the number of steps performed.
Definition: ftn.hpp:197
arma::vec Sample(const State &state, const Action &action, State &nextState)
Dynamics of the FTN System.
Definition: ftn.hpp:121
arma::vec Sample(const State &state, const Action &action)
Dynamics of the FTN System.
Definition: ftn.hpp:158
static constexpr size_t dimension
Dimension of the encoded state.
Definition: ftn.hpp:72
typename impl::find< T, U >::type find
Definition: meta.hpp:238
static constexpr size_t rewardSize
The reward vector consists of {Protein, Carbs, Fats, Vitamins, Minerals, Water} A total of 6 rewards...
Definition: ftn.hpp:206
static const std::unordered_map< size_t, arma::mat > ConvexSetMap
Maps the depth value of a Fruit Tree to its corresponding Convex Convergence Set (CCS)...
Definition: ftn_param.hpp:25
size_t MaxSteps() const
Get the maximum number of steps allowed.
Definition: ftn.hpp:200
bool IsTerminal(const State &state) const
This function checks if the FTN has reached the terminal state.
Definition: ftn.hpp:180
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen docu...
State InitialSample()
This function does null initialization of state space.
Definition: ftn.hpp:168
FruitTreeNavigation(const size_t maxSteps=500, const size_t depth=6)
Construct a Fruit Tree Navigation instance using the given constants.
Definition: ftn.hpp:104
static util::PrefixedOutStream Info
Definition: log.hpp:94
double & Row()
Modify value of row index of the node.
Definition: ftn.hpp:64
State(const arma::colvec &data)
Construct a state instance from given data.
Definition: ftn.hpp:55
State()
Construct a state instance.
Definition: ftn.hpp:48
double Column() const
Get value of column index of the node.
Definition: ftn.hpp:67
Implementation of Fruit Tree Navigation Task.
Definition: ftn.hpp:35