/**
 * @file one_step_sarsa_worker.hpp
 *
 * Definition of the OneStepSarsaWorker class, which acts as a worker for
 * asynchronous one-step SARSA learning.
 */
13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
15 
16 #include <ensmallen.hpp>
18 
19 namespace mlpack {
20 namespace rl {
21 
30 template <
31  typename EnvironmentType,
32  typename NetworkType,
33  typename UpdaterType,
34  typename PolicyType
35 >
36 class OneStepSarsaWorker
37 {
38  public:
39  using StateType = typename EnvironmentType::State;
40  using ActionType = typename EnvironmentType::Action;
41  using TransitionType = std::tuple<StateType, ActionType, double, StateType,
42  ActionType>;
43 
54  const UpdaterType& updater,
55  const EnvironmentType& environment,
56  const TrainingConfig& config,
57  bool deterministic):
58  updater(updater),
59  #if ENS_VERSION_MAJOR >= 2
60  updatePolicy(NULL),
61  #endif
62  environment(environment),
63  config(config),
64  deterministic(deterministic),
65  pending(config.UpdateInterval())
66  { Reset(); }
67 
74  updater(other.updater),
75  #if ENS_VERSION_MAJOR >= 2
76  updatePolicy(NULL),
77  #endif
78  environment(other.environment),
79  config(other.config),
80  deterministic(other.deterministic),
81  steps(other.steps),
82  episodeReturn(other.episodeReturn),
83  pending(other.pending),
84  pendingIndex(other.pendingIndex),
85  network(other.network),
86  state(other.state),
87  action(other.action)
88  {
89  Reset();
90 
91  #if ENS_VERSION_MAJOR >= 2
92  updatePolicy = new typename UpdaterType::template
93  Policy<arma::mat, arma::mat>(updater,
94  network.Parameters().n_rows,
95  network.Parameters().n_cols);
96  #endif
97  }
98 
105  updater(std::move(other.updater)),
106  #if ENS_VERSION_MAJOR >= 2
107  updatePolicy(NULL),
108  #endif
109  environment(std::move(other.environment)),
110  config(std::move(other.config)),
111  deterministic(std::move(other.deterministic)),
112  steps(std::move(other.steps)),
113  episodeReturn(std::move(other.episodeReturn)),
114  pending(std::move(other.pending)),
115  pendingIndex(std::move(other.pendingIndex)),
116  network(std::move(other.network)),
117  state(std::move(other.state)),
118  action(std::move(other.action))
119  {
120  #if ENS_VERSION_MAJOR >= 2
121  other.updatePolicy = NULL;
122 
123  updatePolicy = new typename UpdaterType::template
124  Policy<arma::mat, arma::mat>(updater,
125  network.Parameters().n_rows,
126  network.Parameters().n_cols);
127  #endif
128  }
129 
136  {
137  if (&other == this)
138  return *this;
139 
140  #if ENS_VERSION_MAJOR >= 2
141  delete updatePolicy;
142  #endif
143 
144  updater = other.updater;
145  environment = other.environment;
146  config = other.config;
147  deterministic = other.deterministic;
148  steps = other.steps;
149  episodeReturn = other.episodeReturn;
150  pending = other.pending;
151  pendingIndex = other.pendingIndex;
152  network = other.network;
153  state = other.state;
154  action = other.action;
155 
156  #if ENS_VERSION_MAJOR >= 2
157  updatePolicy = new typename UpdaterType::template
158  Policy<arma::mat, arma::mat>(updater,
159  network.Parameters().n_rows,
160  network.Parameters().n_cols);
161  #endif
162 
163  Reset();
164 
165  return *this;
166  }
167 
174  {
175  if (&other == this)
176  return *this;
177 
178  #if ENS_VERSION_MAJOR >= 2
179  delete updatePolicy;
180  #endif
181 
182  updater = std::move(other.updater);
183  environment = std::move(other.environment);
184  config = std::move(other.config);
185  deterministic = std::move(other.deterministic);
186  steps = std::move(other.steps);
187  episodeReturn = std::move(other.episodeReturn);
188  pending = std::move(other.pending);
189  pendingIndex = std::move(other.pendingIndex);
190  network = std::move(other.network);
191  state = std::move(other.state);
192  action = std::move(other.action);
193 
194  #if ENS_VERSION_MAJOR >= 2
195  other.updatePolicy = NULL;
196 
197  updatePolicy = new typename UpdaterType::template
198  Policy<arma::mat, arma::mat>(updater,
199  network.Parameters().n_rows,
200  network.Parameters().n_cols);
201  #endif
202 
203  return *this;
204  }
205 
210  {
211  #if ENS_VERSION_MAJOR >= 2
212  delete updatePolicy;
213  #endif
214  }
215 
220  void Initialize(NetworkType& learningNetwork)
221  {
222  #if ENS_VERSION_MAJOR == 1
223  updater.Initialize(learningNetwork.Parameters().n_rows,
224  learningNetwork.Parameters().n_cols);
225  #else
226  delete updatePolicy;
227 
228  updatePolicy = new typename UpdaterType::template
229  Policy<arma::mat, arma::mat>(updater,
230  learningNetwork.Parameters().n_rows,
231  learningNetwork.Parameters().n_cols);
232  #endif
233 
234  // Build local network.
235  network = learningNetwork;
236  }
237 
249  bool Step(NetworkType& learningNetwork,
250  NetworkType& targetNetwork,
251  size_t& totalSteps,
252  PolicyType& policy,
253  double& totalReward)
254  {
255  // Interact with the environment.
256  if (action.action == ActionType::size)
257  {
258  // Invalid action means we are at the beginning of an episode.
259  arma::colvec actionValue;
260  network.Predict(state.Encode(), actionValue);
261  action = policy.Sample(actionValue, deterministic);
262  }
263  StateType nextState;
264  double reward = environment.Sample(state, action, nextState);
265  bool terminal = environment.IsTerminal(nextState);
266  arma::colvec actionValue;
267  network.Predict(nextState.Encode(), actionValue);
268  ActionType nextAction = policy.Sample(actionValue, deterministic);
269 
270  episodeReturn += reward;
271  steps++;
272 
273  terminal = terminal || steps >= config.StepLimit();
274  if (deterministic)
275  {
276  if (terminal)
277  {
278  totalReward = episodeReturn;
279  Reset();
280  // Sync with latest learning network.
281  network = learningNetwork;
282  return true;
283  }
284  state = nextState;
285  action = nextAction;
286  return false;
287  }
288 
289  #pragma omp atomic
290  totalSteps++;
291 
292  pending[pendingIndex++] =
293  std::make_tuple(state, action, reward, nextState, nextAction);
294 
295  if (terminal || pendingIndex >= config.UpdateInterval())
296  {
297  // Initialize the gradient storage.
298  arma::mat totalGradients(learningNetwork.Parameters().n_rows,
299  learningNetwork.Parameters().n_cols, arma::fill::zeros);
300  for (size_t i = 0; i < pending.size(); ++i)
301  {
302  TransitionType &transition = pending[i];
303 
304  // Compute the target state-action value.
305  arma::colvec actionValue;
306  #pragma omp critical
307  {
308  targetNetwork.Predict(
309  std::get<3>(transition).Encode(), actionValue);
310  };
311  double targetActionValue = 0;
312  if (!(terminal && i == pending.size() - 1))
313  targetActionValue = actionValue[std::get<4>(transition).action];
314  targetActionValue = std::get<2>(transition) +
315  config.Discount() * targetActionValue;
316 
317  // Compute the training target for current state.
318  arma::mat input = std::get<0>(transition).Encode();
319  network.Forward(input, actionValue);
320  actionValue[std::get<1>(transition).action] = targetActionValue;
321 
322  // Compute gradient.
323  arma::mat gradients;
324  network.Backward(input, actionValue, gradients);
325 
326  // Accumulate gradients.
327  totalGradients += gradients;
328  }
329 
330  // Clamp the accumulated gradients.
331  totalGradients.transform(
332  [&](double gradient)
333  { return std::min(std::max(gradient, -config.GradientLimit()),
334  config.GradientLimit()); });
335 
336  // Perform async update of the global network.
337  #if ENS_VERSION_MAJOR == 1
338  updater.Update(learningNetwork.Parameters(), config.StepSize(),
339  totalGradients);
340  #else
341  updatePolicy->Update(learningNetwork.Parameters(),
342  config.StepSize(), totalGradients);
343  #endif
344 
345  // Sync the local network with the global network.
346  network = learningNetwork;
347 
348  pendingIndex = 0;
349  }
350 
351  // Update global target network.
352  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
353  {
354  #pragma omp critical
355  { targetNetwork = learningNetwork; }
356  }
357 
358  policy.Anneal();
359 
360  if (terminal)
361  {
362  totalReward = episodeReturn;
363  Reset();
364  return true;
365  }
366  state = nextState;
367  action = nextAction;
368  return false;
369  }
370 
371  private:
375  void Reset()
376  {
377  steps = 0;
378  episodeReturn = 0;
379  pendingIndex = 0;
380  state = environment.InitialSample();
381  using actions = typename EnvironmentType::Action::actions;
382  action.action = static_cast<actions>(ActionType::size);
383  }
384 
386  UpdaterType updater;
387  #if ENS_VERSION_MAJOR >= 2
388  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
389  #endif
390 
392  EnvironmentType environment;
393 
395  TrainingConfig config;
396 
398  bool deterministic;
399 
401  size_t steps;
402 
404  double episodeReturn;
405 
407  std::vector<TransitionType> pending;
408 
410  size_t pendingIndex;
411 
413  NetworkType network;
414 
416  StateType state;
417 
419  ActionType action;
420 };
421 
422 } // namespace rl
423 } // namespace mlpack
424 
425 #endif
// Doxygen cross-reference residue from the generated documentation page,
// preserved as comments (member signatures of OneStepSarsaWorker) so the
// header remains compilable:
//
//   OneStepSarsaWorker(const UpdaterType& updater,
//       const EnvironmentType& environment, const TrainingConfig& config,
//       bool deterministic)
//     Construct one step sarsa worker with the given parameters and
//     environment.
//   OneStepSarsaWorker(const OneStepSarsaWorker& other)
//     Copy another OneStepSarsaWorker.
//   OneStepSarsaWorker(OneStepSarsaWorker&& other)
//     Take ownership of another OneStepSarsaWorker.
//   OneStepSarsaWorker& operator=(const OneStepSarsaWorker& other)
//     Copy another OneStepSarsaWorker.
//   OneStepSarsaWorker& operator=(OneStepSarsaWorker&& other)
//     Take ownership of another OneStepSarsaWorker.
//   void Initialize(NetworkType& learningNetwork)
//     Initialize the worker.
//   bool Step(NetworkType& learningNetwork, NetworkType& targetNetwork,
//       size_t& totalSteps, PolicyType& policy, double& totalReward)
//     The agent will execute one step.
//
//   using StateType = typename EnvironmentType::State;
//   using ActionType = typename EnvironmentType::Action;
//   using TransitionType =
//       std::tuple<StateType, ActionType, double, StateType, ActionType>;
//
//   TrainingConfig accessors referenced: UpdateInterval(), StepLimit(),
//   TargetNetworkSyncInterval(), Discount(), StepSize(), GradientLimit().