one_step_q_learning_worker.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
14 #define MLPACK_METHODS_RL_WORKER_ONE_STEP_Q_LEARNING_WORKER_HPP
15 
16 #include <ensmallen.hpp>
18 
19 namespace mlpack {
20 namespace rl {
21 
30 template <
31  typename EnvironmentType,
32  typename NetworkType,
33  typename UpdaterType,
34  typename PolicyType
35 >
36 class OneStepQLearningWorker
37 {
38  public:
39  using StateType = typename EnvironmentType::State;
40  using ActionType = typename EnvironmentType::Action;
41  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;
42 
53  const UpdaterType& updater,
54  const EnvironmentType& environment,
55  const TrainingConfig& config,
56  bool deterministic):
57  updater(updater),
58  #if ENS_VERSION_MAJOR >= 2
59  updatePolicy(NULL),
60  #endif
61  environment(environment),
62  config(config),
63  deterministic(deterministic),
64  pending(config.UpdateInterval())
65  { Reset(); }
66 
73  updater(other.updater),
74  #if ENS_VERSION_MAJOR >= 2
75  updatePolicy(NULL),
76  #endif
77  environment(other.environment),
78  config(other.config),
79  deterministic(other.deterministic),
80  steps(other.steps),
81  episodeReturn(other.episodeReturn),
82  pending(other.pending),
83  pendingIndex(other.pendingIndex),
84  network(other.network),
85  state(other.state)
86  {
87  #if ENS_VERSION_MAJOR >= 2
88  updatePolicy = new typename UpdaterType::template
89  Policy<arma::mat, arma::mat>(updater,
90  network.Parameters().n_rows,
91  network.Parameters().n_cols);
92  #endif
93 
94  Reset();
95  }
96 
103  updater(std::move(other.updater)),
104  #if ENS_VERSION_MAJOR >= 2
105  updatePolicy(NULL),
106  #endif
107  environment(std::move(other.environment)),
108  config(std::move(other.config)),
109  deterministic(std::move(other.deterministic)),
110  steps(std::move(other.steps)),
111  episodeReturn(std::move(other.episodeReturn)),
112  pending(std::move(other.pending)),
113  pendingIndex(std::move(other.pendingIndex)),
114  network(std::move(other.network)),
115  state(std::move(other.state))
116  {
117  #if ENS_VERSION_MAJOR >= 2
118  other.updatePolicy = NULL;
119 
120  updatePolicy = new typename UpdaterType::template
121  Policy<arma::mat, arma::mat>(updater,
122  network.Parameters().n_rows,
123  network.Parameters().n_cols);
124  #endif
125  }
126 
133  {
134  if (&other == this)
135  return *this;
136 
137  #if ENS_VERSION_MAJOR >= 2
138  delete updatePolicy;
139  #endif
140 
141  updater = other.updater;
142  environment = other.environment;
143  config = other.config;
144  deterministic = other.deterministic;
145  steps = other.steps;
146  episodeReturn = other.episodeReturn;
147  pending = other.pending;
148  pendingIndex = other.pendingIndex;
149  network = other.network;
150  state = other.state;
151 
152  #if ENS_VERSION_MAJOR >= 2
153  updatePolicy = new typename UpdaterType::template
154  Policy<arma::mat, arma::mat>(updater,
155  network.Parameters().n_rows,
156  network.Parameters().n_cols);
157  #endif
158 
159  Reset();
160 
161  return *this;
162  }
163 
170  {
171  if (&other == this)
172  return *this;
173 
174  #if ENS_VERSION_MAJOR >= 2
175  delete updatePolicy;
176  #endif
177 
178  updater = std::move(other.updater);
179  environment = std::move(other.environment);
180  config = std::move(other.config);
181  deterministic = std::move(other.deterministic);
182  steps = std::move(other.steps);
183  episodeReturn = std::move(other.episodeReturn);
184  pending = std::move(other.pending);
185  pendingIndex = std::move(other.pendingIndex);
186  network = std::move(other.network);
187  state = std::move(other.state);
188 
189  #if ENS_VERSION_MAJOR >= 2
190  other.updatePolicy = NULL;
191 
192  updatePolicy = new typename UpdaterType::template
193  Policy<arma::mat, arma::mat>(updater,
194  network.Parameters().n_rows,
195  network.Parameters().n_cols);
196  #endif
197 
198  return *this;
199  }
200 
205  {
206  #if ENS_VERSION_MAJOR >= 2
207  delete updatePolicy;
208  #endif
209  }
210 
215  void Initialize(NetworkType& learningNetwork)
216  {
217  #if ENS_VERSION_MAJOR == 1
218  updater.Initialize(learningNetwork.Parameters().n_rows,
219  learningNetwork.Parameters().n_cols);
220  #else
221  delete updatePolicy;
222 
223  updatePolicy = new typename UpdaterType::template
224  Policy<arma::mat, arma::mat>(updater,
225  learningNetwork.Parameters().n_rows,
226  learningNetwork.Parameters().n_cols);
227  #endif
228 
229  // Build local network.
230  network = learningNetwork;
231  }
232 
244  bool Step(NetworkType& learningNetwork,
245  NetworkType& targetNetwork,
246  size_t& totalSteps,
247  PolicyType& policy,
248  double& totalReward)
249  {
250  // Interact with the environment.
251  arma::colvec actionValue;
252  network.Predict(state.Encode(), actionValue);
253  ActionType action = policy.Sample(actionValue, deterministic);
254  StateType nextState;
255  double reward = environment.Sample(state, action, nextState);
256  bool terminal = environment.IsTerminal(nextState);
257 
258  episodeReturn += reward;
259  steps++;
260 
261  terminal = terminal || steps >= config.StepLimit();
262  if (deterministic)
263  {
264  if (terminal)
265  {
266  totalReward = episodeReturn;
267  Reset();
268  // Sync with latest learning network.
269  network = learningNetwork;
270  return true;
271  }
272  state = nextState;
273  return false;
274  }
275 
276  #pragma omp atomic
277  totalSteps++;
278 
279  pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
280  pendingIndex++;
281 
282  if (terminal || pendingIndex >= config.UpdateInterval())
283  {
284  // Initialize the gradient storage.
285  arma::mat totalGradients(learningNetwork.Parameters().n_rows,
286  learningNetwork.Parameters().n_cols, arma::fill::zeros);
287  for (size_t i = 0; i < pending.size(); ++i)
288  {
289  TransitionType &transition = pending[i];
290 
291  // Compute the target state-action value.
292  arma::colvec actionValue;
293  #pragma omp critical
294  {
295  targetNetwork.Predict(
296  std::get<3>(transition).Encode(), actionValue);
297  };
298  double targetActionValue = actionValue.max();
299  if (terminal && i == pending.size() - 1)
300  targetActionValue = 0;
301  targetActionValue = std::get<2>(transition) +
302  config.Discount() * targetActionValue;
303 
304  // Compute the training target for current state.
305  arma::mat input = std::get<0>(transition).Encode();
306  network.Forward(input, actionValue);
307  actionValue[std::get<1>(transition).action] = targetActionValue;
308 
309  // Compute gradient.
310  arma::mat gradients;
311  network.Backward(input, actionValue, gradients);
312 
313  // Accumulate gradients.
314  totalGradients += gradients;
315  }
316 
317  // Clamp the accumulated gradients.
318  totalGradients.transform(
319  [&](double gradient)
320  { return std::min(std::max(gradient, -config.GradientLimit()),
321  config.GradientLimit()); });
322 
323  // Perform async update of the global network.
324  #if ENS_VERSION_MAJOR == 1
325  updater.Update(learningNetwork.Parameters(), config.StepSize(),
326  totalGradients);
327  #else
328  updatePolicy->Update(learningNetwork.Parameters(),
329  config.StepSize(), totalGradients);
330  #endif
331 
332  // Sync the local network with the global network.
333  network = learningNetwork;
334 
335  pendingIndex = 0;
336  }
337 
338  // Update global target network.
339  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
340  {
341  #pragma omp critical
342  { targetNetwork = learningNetwork; }
343  }
344 
345  policy.Anneal();
346 
347  if (terminal)
348  {
349  totalReward = episodeReturn;
350  Reset();
351  return true;
352  }
353  state = nextState;
354  return false;
355  }
356 
357  private:
361  void Reset()
362  {
363  steps = 0;
364  episodeReturn = 0;
365  pendingIndex = 0;
366  state = environment.InitialSample();
367  }
368 
370  UpdaterType updater;
371  #if ENS_VERSION_MAJOR >= 2
372  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
373  #endif
374 
376  EnvironmentType environment;
377 
379  TrainingConfig config;
380 
382  bool deterministic;
383 
385  size_t steps;
386 
388  double episodeReturn;
389 
391  std::vector<TransitionType> pending;
392 
394  size_t pendingIndex;
395 
397  NetworkType network;
398 
400  StateType state;
401 };
402 
403 } // namespace rl
404 } // namespace mlpack
405 
406 #endif
void Initialize(NetworkType &learningNetwork)
Initialize the worker.
Linear algebra utility functions, generally performed on matrices or vectors.
std::tuple< StateType, ActionType, double, StateType > TransitionType
size_t StepLimit() const
Get the maximum steps of each episode.
size_t TargetNetworkSyncInterval() const
Get the interval for syncing target network.
constexpr T const & max(T const &lhs, T const &rhs)
Definition: algorithm.hpp:79
OneStepQLearningWorker & operator=(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
OneStepQLearningWorker(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
OneStepQLearningWorker(const UpdaterType &updater, const EnvironmentType &environment, const TrainingConfig &config, bool deterministic)
Construct one step Q-Learning worker with the given parameters and environment.
Forward declaration of OneStepQLearningWorker.
size_t UpdateInterval() const
Get the update interval.
double Discount() const
Get the discount rate for future reward.
typename EnvironmentType::Action ActionType
OneStepQLearningWorker(OneStepQLearningWorker &&other)
Take ownership of another OneStepQLearningWorker.
bool Step(NetworkType &learningNetwork, NetworkType &targetNetwork, size_t &totalSteps, PolicyType &policy, double &totalReward)
The agent will execute one step.
constexpr T const & min(T const &lhs, T const &rhs)
Definition: algorithm.hpp:69
cannot build Julia bindings endif() else() find_package(Julia 0.7.0) if(NOT JULIA_FOUND) unset(BUILD_JULIA_BINDINGS CACHE) endif() endif() if(NOT JULIA_FOUND) not_found_return("Julia not found
Definition: CMakeLists.txt:45
OneStepQLearningWorker & operator=(const OneStepQLearningWorker &other)
Copy another OneStepQLearningWorker.
double GradientLimit() const
Get the limit of update gradient.
auto move(Range &&rng, OutputIt &&it) -> enable_if_t< is_range< Range >::value, decay_t< OutputIt > >
Definition: algorithm.hpp:736
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() macro(post_go_setup) if(BUILD_GO_BINDINGS) file(APPEND "$
Definition: CMakeLists.txt:3
double StepSize() const
Get the step size of the optimizer.
typename EnvironmentType::State StateType