one_step_sarsa_worker.hpp
/**
 * @file one_step_sarsa_worker.hpp
 *
 * Definition of the OneStepSarsaWorker class, which runs episodes for the
 * asynchronous one-step SARSA algorithm.
 */
#ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_ONE_STEP_SARSA_WORKER_HPP

// Provides ENS_VERSION_MAJOR and the optimizer update policies used below.
#include <ensmallen.hpp>

namespace mlpack {
namespace rl {

/**
 * One-step SARSA worker for asynchronous (multi-worker) training.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class OneStepSarsaWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;

  //! A SARSA transition: (state, action, reward, next state, next action).
  using TransitionType = std::tuple<StateType, ActionType, double, StateType,
      ActionType>;

  /**
   * Construct a one-step SARSA worker with the given parameters and
   * environment.
   *
   * @param updater The optimizer used to apply gradients.
   * @param environment The reinforcement learning task.
   * @param config Hyper-parameters for training.
   * @param deterministic Whether the worker should act deterministically
   *     (evaluation mode; no updates are performed).
   */
  OneStepSarsaWorker(
      const UpdaterType& updater,
      const EnvironmentType& environment,
      const TrainingConfig& config,
      bool deterministic) :
      updater(updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(environment),
      config(config),
      deterministic(deterministic),
      pending(config.UpdateInterval())
  { Reset(); }

  /**
   * Copy another OneStepSarsaWorker.
   *
   * @param other The worker to copy.
   */
  OneStepSarsaWorker(const OneStepSarsaWorker& other) :
      updater(other.updater),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(other.environment),
      config(other.config),
      deterministic(other.deterministic),
      steps(other.steps),
      episodeReturn(other.episodeReturn),
      pending(other.pending),
      pendingIndex(other.pendingIndex),
      network(other.network),
      state(other.state),
      action(other.action)
  {
    Reset();

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        network.Parameters().n_rows,
        network.Parameters().n_cols);
    #endif
  }

  /**
   * Take ownership of another OneStepSarsaWorker.
   *
   * @param other The worker to move from.
   */
  OneStepSarsaWorker(OneStepSarsaWorker&& other) :
      updater(std::move(other.updater)),
      #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
      #endif
      environment(std::move(other.environment)),
      config(std::move(other.config)),
      deterministic(std::move(other.deterministic)),
      steps(std::move(other.steps)),
      episodeReturn(std::move(other.episodeReturn)),
      pending(std::move(other.pending)),
      pendingIndex(std::move(other.pendingIndex)),
      network(std::move(other.network)),
      state(std::move(other.state)),
      action(std::move(other.action))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        network.Parameters().n_rows,
        network.Parameters().n_cols);
    #endif
  }

  /**
   * Copy another OneStepSarsaWorker.
   *
   * @param other The worker to copy.
   */
  OneStepSarsaWorker& operator=(const OneStepSarsaWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;
    action = other.action;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        network.Parameters().n_rows,
        network.Parameters().n_cols);
    #endif

    Reset();

    return *this;
  }

  /**
   * Take ownership of another OneStepSarsaWorker.
   *
   * @param other The worker to move from.
   */
  OneStepSarsaWorker& operator=(OneStepSarsaWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);
    action = std::move(other.action);

    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        network.Parameters().n_rows,
        network.Parameters().n_cols);
    #endif

    return *this;
  }

  /**
   * Clean up the worker; free the allocated update policy, if any.
   */
  ~OneStepSarsaWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  /**
   * Initialize the worker.
   *
   * @param learningNetwork The shared learning network to initialize from.
   */
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
        learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
        learningNetwork.Parameters().n_rows,
        learningNetwork.Parameters().n_cols);
    #endif

    // Build the local copy of the network; gradients are computed against
    // this copy and applied asynchronously to the shared network.
    network = learningNetwork;
  }

  /**
   * The agent will execute one step.
   *
   * @param learningNetwork The shared learning network.
   * @param targetNetwork The shared target network.
   * @param totalSteps The shared counter of total steps.
   * @param policy The shared behavior policy.
   * @param totalReward Set to the episode return if the episode ends after
   *     this step; otherwise it is not modified.
   * @return Whether the current episode ends after this step.
   */
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    if (action == ActionType::size)
    {
      // An invalid action means we are at the beginning of an episode, so
      // sample the first action from the current policy.
      arma::colvec actionValue;
      network.Predict(state.Encode(), actionValue);
      action = policy.Sample(actionValue, deterministic);
    }
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);
    arma::colvec actionValue;
    network.Predict(nextState.Encode(), actionValue);
    ActionType nextAction = policy.Sample(actionValue, deterministic);

    episodeReturn += reward;
    steps++;

    terminal = terminal || steps >= config.StepLimit();
    if (deterministic)
    {
      // In deterministic (evaluation) mode no update is performed; just
      // report the episode return when the episode ends.
      if (terminal)
      {
        totalReward = episodeReturn;
        Reset();
        // Sync with the latest learning network.
        network = learningNetwork;
        return true;
      }
      state = nextState;
      action = nextAction;
      return false;
    }

    #pragma omp atomic
    totalSteps++;

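    // Record the transition (state, action, reward, nextState, nextAction).
    // Gradients are computed and applied only once config.UpdateInterval()
    // transitions have accumulated in the pending buffer, or when the
    // episode ends.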
    pending[pendingIndex++] =
        std::make_tuple(state, action, reward, nextState, nextAction);

    if (terminal || pendingIndex >= config.UpdateInterval())
    {
      // Initialize the gradient storage.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);
      for (size_t i = 0; i < pending.size(); ++i)
      {
        TransitionType& transition = pending[i];

        // Compute the target state-action value.
        arma::colvec actionValue;
        #pragma omp critical
        {
          targetNetwork.Predict(
              std::get<3>(transition).Encode(), actionValue);
        }
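        // One-step SARSA bootstrap target: y = r + discount * Q(s', a'),
        // where a' is the action actually sampled in s' (using the sampled
        // next action rather than the max is what makes this on-policy
        // SARSA instead of Q-learning). For the final transition of a
        // finished episode the bootstrap term is dropped.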
        double targetActionValue = 0;
        if (!(terminal && i == pending.size() - 1))
          targetActionValue = actionValue[std::get<4>(transition)];
        targetActionValue = std::get<2>(transition) +
            config.Discount() * targetActionValue;

        // Compute the training target for the current state.
        network.Forward(std::get<0>(transition).Encode(), actionValue);
        actionValue[std::get<1>(transition)] = targetActionValue;

        // Compute the gradient.
        arma::mat gradients;
        network.Backward(actionValue, gradients);

        // Accumulate gradients.
        totalGradients += gradients;
      }

      // Clamp the accumulated gradients to [-GradientLimit, GradientLimit].
      totalGradients.transform(
          [&](double gradient)
          { return std::min(std::max(gradient, -config.GradientLimit()),
              config.GradientLimit()); });

      // Perform an asynchronous update of the global network.
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(),
          config.StepSize(), totalGradients);
      #endif

      // Sync the local network with the global network.
      network = learningNetwork;

      pendingIndex = 0;
    }

    // Update the global target network.
    if (totalSteps % config.TargetNetworkSyncInterval() == 0)
    {
      #pragma omp critical
      { targetNetwork = learningNetwork; }
    }

    policy.Anneal();

    if (terminal)
    {
      totalReward = episodeReturn;
      Reset();
      return true;
    }
    state = nextState;
    action = nextAction;
    return false;
  }

 private:
  /**
   * Reset the internal state of the worker for a new episode.
   */
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
    action = ActionType::size;
  }

  //! Locally-stored optimizer.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif

  //! Locally-stored task.
  EnvironmentType environment;

  //! Locally-stored hyper-parameters.
  TrainingConfig config;

  //! Whether this worker runs in deterministic (evaluation) mode.
  bool deterministic;

  //! Total steps of the current episode.
  size_t steps;

  //! Total reward of the current episode.
  double episodeReturn;

  //! Buffer of pending transitions used for training.
  std::vector<TransitionType> pending;

  //! Current position in the pending buffer.
  size_t pendingIndex;

  //! Local copy of the learning network.
  NetworkType network;

  //! Current state of the agent.
  StateType state;

  //! Current action of the agent.
  ActionType action;
};

} // namespace rl
} // namespace mlpack

#endif
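
For orientation, here is a minimal sketch of how a worker of this type can be driven. In mlpack this loop is normally owned by the asynchronous learning wrapper rather than written by hand; the environment, network, updater, and policy types below, along with the updater, environment, config, learningNetwork, targetNetwork, policy, and maxSteps variables, are placeholders assumed to exist with compatible interfaces. Only the worker calls themselves come from this header.

// Hypothetical driver loop for a OneStepSarsaWorker (placeholder types).
OneStepSarsaWorker<MyEnvironment, MyNetwork, MyUpdater, MyPolicy> worker(
    updater, environment, config, /* deterministic */ false);

// Copy the shared learning network into the worker's local network and set
// up the optimizer state.
worker.Initialize(learningNetwork);

size_t totalSteps = 0;
double episodeReturn = 0.0;
for (size_t i = 0; i < maxSteps; ++i)
{
  // Step() returns true when the current episode has ended; episodeReturn
  // is only meaningful in that case.
  if (worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
      episodeReturn))
  {
    // The episode finished; episodeReturn holds its total reward.
  }
}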