n_step_q_learning_worker.hpp
/**
 * @file n_step_q_learning_worker.hpp
 *
 * Definition of the NStepQLearningWorker class, which implements an episode
 * of the async n-step Q-Learning algorithm.
 */
#ifndef MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP
#define MLPACK_METHODS_RL_WORKER_N_STEP_Q_LEARNING_WORKER_HPP

#include <mlpack/prereqs.hpp>

namespace mlpack {
namespace rl {

/**
 * N-step Q-Learning worker. Each worker holds its own copy of the environment
 * and a local network, collects up to UpdateInterval() transitions, and then
 * pushes the accumulated n-step gradients to the shared learning network.
 *
 * @tparam EnvironmentType The type of the reinforcement learning task.
 * @tparam NetworkType The type of the network model.
 * @tparam UpdaterType The type of the optimizer.
 * @tparam PolicyType The type of the behavior policy.
 */
template <
  typename EnvironmentType,
  typename NetworkType,
  typename UpdaterType,
  typename PolicyType
>
class NStepQLearningWorker
{
 public:
  using StateType = typename EnvironmentType::State;
  using ActionType = typename EnvironmentType::Action;
  //! A transition is (state, action, reward, next state).
  using TransitionType = std::tuple<StateType, ActionType, double, StateType>;

  /**
   * Construct an N-step Q-Learning worker with the given parameters and
   * environment.
   *
   * @param updater The optimizer.
   * @param environment The reinforcement learning task.
   * @param config Hyper-parameters.
   * @param deterministic Whether the worker should be deterministic.
   */
  NStepQLearningWorker(
      const UpdaterType& updater,
      const EnvironmentType& environment,
      const TrainingConfig& config,
      bool deterministic):
      updater(updater),
    #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
    #endif
      environment(environment),
      config(config),
      deterministic(deterministic),
      pending(config.UpdateInterval())
  { Reset(); }

  /**
   * Copy another NStepQLearningWorker.
   *
   * @param other NStepQLearningWorker to copy.
   */
  NStepQLearningWorker(const NStepQLearningWorker& other):
      updater(other.updater),
    #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
    #endif
      environment(other.environment),
      config(other.config),
      deterministic(other.deterministic),
      steps(other.steps),
      episodeReturn(other.episodeReturn),
      pending(other.pending),
      pendingIndex(other.pendingIndex),
      network(other.network),
      state(other.state)
  {
    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    Reset();
  }

  /**
   * Take ownership of another NStepQLearningWorker.
   *
   * @param other NStepQLearningWorker to take ownership of.
   */
  NStepQLearningWorker(NStepQLearningWorker&& other):
      updater(std::move(other.updater)),
    #if ENS_VERSION_MAJOR >= 2
      updatePolicy(NULL),
    #endif
      environment(std::move(other.environment)),
      config(std::move(other.config)),
      deterministic(std::move(other.deterministic)),
      steps(std::move(other.steps)),
      episodeReturn(std::move(other.episodeReturn)),
      pending(std::move(other.pending)),
      pendingIndex(std::move(other.pendingIndex)),
      network(std::move(other.network)),
      state(std::move(other.state))
  {
    #if ENS_VERSION_MAJOR >= 2
    other.updatePolicy = NULL;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif
  }

  /**
   * Copy another NStepQLearningWorker.
   *
   * @param other NStepQLearningWorker to copy.
   */
  NStepQLearningWorker& operator=(const NStepQLearningWorker& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = other.updater;
    environment = other.environment;
    config = other.config;
    deterministic = other.deterministic;
    steps = other.steps;
    episodeReturn = other.episodeReturn;
    pending = other.pending;
    pendingIndex = other.pendingIndex;
    network = other.network;
    state = other.state;

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);
    #endif

    Reset();

    return *this;
  }

  /**
   * Take ownership of another NStepQLearningWorker.
   *
   * @param other NStepQLearningWorker to take ownership of.
   */
  NStepQLearningWorker& operator=(NStepQLearningWorker&& other)
  {
    if (&other == this)
      return *this;

    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif

    updater = std::move(other.updater);
    environment = std::move(other.environment);
    config = std::move(other.config);
    deterministic = std::move(other.deterministic);
    steps = std::move(other.steps);
    episodeReturn = std::move(other.episodeReturn);
    pending = std::move(other.pending);
    pendingIndex = std::move(other.pendingIndex);
    network = std::move(other.network);
    state = std::move(other.state);

    #if ENS_VERSION_MAJOR >= 2
    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     network.Parameters().n_rows,
                                     network.Parameters().n_cols);

    other.updatePolicy = NULL;
    #endif

    return *this;
  }

  /**
   * Clean up any allocated memory.
   */
  ~NStepQLearningWorker()
  {
    #if ENS_VERSION_MAJOR >= 2
    delete updatePolicy;
    #endif
  }

  /**
   * Initialize the worker.
   *
   * @param learningNetwork The shared learning network.
   */
  void Initialize(NetworkType& learningNetwork)
  {
    #if ENS_VERSION_MAJOR == 1
    updater.Initialize(learningNetwork.Parameters().n_rows,
                       learningNetwork.Parameters().n_cols);
    #else
    delete updatePolicy;

    updatePolicy = new typename UpdaterType::template
        Policy<arma::mat, arma::mat>(updater,
                                     learningNetwork.Parameters().n_rows,
                                     learningNetwork.Parameters().n_cols);
    #endif

    // Build the local network as a copy of the shared learning network.
    network = learningNetwork;
  }

  /**
   * The agent will execute one step.
   *
   * @param learningNetwork The shared learning network.
   * @param targetNetwork The shared target network.
   * @param totalSteps The shared counter of total steps.
   * @param policy The shared behavior policy.
   * @param totalReward To be filled with the episode return when the episode
   *     ends.
   * @return Whether the current episode ends after this step.
   */
  bool Step(NetworkType& learningNetwork,
            NetworkType& targetNetwork,
            size_t& totalSteps,
            PolicyType& policy,
            double& totalReward)
  {
    // Interact with the environment.
    arma::colvec actionValue;
    network.Predict(state.Encode(), actionValue);
    ActionType action = policy.Sample(actionValue, deterministic);
    StateType nextState;
    double reward = environment.Sample(state, action, nextState);
    bool terminal = environment.IsTerminal(nextState);

    episodeReturn += reward;
    steps++;

    terminal = terminal || steps >= config.StepLimit();
    if (deterministic)
    {
      if (terminal)
      {
        totalReward = episodeReturn;
        Reset();
        // Sync with the latest learning network.
        network = learningNetwork;
        return true;
      }
      state = nextState;
      return false;
    }

    #pragma omp atomic
    totalSteps++;

    pending[pendingIndex] = std::make_tuple(state, action, reward, nextState);
    pendingIndex++;

    if (terminal || pendingIndex >= config.UpdateInterval())
    {
      // Initialize the gradient storage.
      arma::mat totalGradients(learningNetwork.Parameters().n_rows,
          learningNetwork.Parameters().n_cols, arma::fill::zeros);

      // Bootstrap from the value of the next state.
      arma::colvec actionValue;
      double target = 0;
      if (!terminal)
      {
        #pragma omp critical
        { targetNetwork.Predict(nextState.Encode(), actionValue); }
        target = actionValue.max();
      }

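      // The loop below forms the n-step bootstrapped target by walking the
      // pending transitions backwards: starting from
      // target = max_a Q_target(nextState, a) (or 0 if the episode ended),
      // each iteration applies target = reward_i + Discount() * target, so
      // transition i is regressed towards its full n-step return.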
      // Update in reverse order.
      for (int i = pending.size() - 1; i >= 0; --i)
      {
        TransitionType& transition = pending[i];
        target = config.Discount() * target + std::get<2>(transition);

        // Compute the training target for the current state.
        network.Forward(std::get<0>(transition).Encode(), actionValue);
        actionValue[std::get<1>(transition)] = target;

        // Compute gradient.
        arma::mat gradients;
        network.Backward(actionValue, gradients);

        // Accumulate gradients.
        totalGradients += gradients;
      }

      // Clamp the accumulated gradients.
      totalGradients.transform(
          [&](double gradient)
          { return std::min(std::max(gradient, -config.GradientLimit()),
                            config.GradientLimit()); });

      // Perform an async update of the global network.
      #if ENS_VERSION_MAJOR == 1
      updater.Update(learningNetwork.Parameters(), config.StepSize(),
          totalGradients);
      #else
      updatePolicy->Update(learningNetwork.Parameters(),
          config.StepSize(), totalGradients);
      #endif

      // Sync the local network with the global network.
      network = learningNetwork;

      pendingIndex = 0;
    }

335 
336  // Update global target network.
337  if (totalSteps % config.TargetNetworkSyncInterval() == 0)
338  {
339  #pragma omp critical
340  { targetNetwork = learningNetwork; }
341  }
342 
343  policy.Anneal();
344 
345  if (terminal)
346  {
347  totalReward = episodeReturn;
348  Reset();
349  return true;
350  }
351  state = nextState;
352  return false;
353  }
354 
 private:
  /**
   * Reset the internal state of the worker for a new episode.
   */
  void Reset()
  {
    steps = 0;
    episodeReturn = 0;
    pendingIndex = 0;
    state = environment.InitialSample();
  }

  //! Locally-stored optimizer.
  UpdaterType updater;
  #if ENS_VERSION_MAJOR >= 2
  //! The update policy instantiated from the optimizer (ensmallen >= 2).
  typename UpdaterType::template Policy<arma::mat, arma::mat>* updatePolicy;
  #endif

  //! Locally-stored task.
  EnvironmentType environment;

  //! Locally-stored hyper-parameters.
  TrainingConfig config;

  //! Whether this worker runs in deterministic (evaluation) mode.
  bool deterministic;

  //! Total steps in the current episode.
  size_t steps;

  //! Total reward accumulated in the current episode.
  double episodeReturn;

  //! Buffer for the n most recent transitions.
  std::vector<TransitionType> pending;

  //! Current position in the pending buffer.
  size_t pendingIndex;

  //! Local copy of the learning network.
  NetworkType network;

  //! The current state of the agent.
  StateType state;
};

} // namespace rl
} // namespace mlpack

#endif
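
Since this header only defines the worker itself, a short driver can make the interface above concrete. The sketch below runs a single worker by hand; in practice the async learning machinery constructs the workers and runs one per thread. The CartPole environment, GreedyPolicy, FFN model, ens::VanillaUpdate updater, the layer sizes, the hyper-parameter values, and the include paths are illustrative assumptions (based on mlpack 3.x's layout), not a prescribed setup.

// Minimal single-threaded driver sketch; all concrete types and values below
// are assumptions for illustration only.
#include <iostream>
#include <mlpack/prereqs.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/reinforcement_learning/environment/cart_pole.hpp>
#include <mlpack/methods/reinforcement_learning/policy/greedy_policy.hpp>
#include <mlpack/methods/reinforcement_learning/training_config.hpp>
#include <mlpack/methods/reinforcement_learning/worker/n_step_q_learning_worker.hpp>

using namespace mlpack::ann;
using namespace mlpack::rl;

int main()
{
  // Task and behavior policy (initial epsilon, anneal interval, min epsilon).
  CartPole environment;
  GreedyPolicy<CartPole> policy(1.0, 1000, 0.1);

  // Shared learning and target networks: 4 state features -> 2 action values.
  FFN<MeanSquaredError<>, GaussianInitialization> learningNetwork;
  learningNetwork.Add<Linear<>>(4, 20);
  learningNetwork.Add<ReLULayer<>>();
  learningNetwork.Add<Linear<>>(20, 2);
  learningNetwork.ResetParameters();
  FFN<MeanSquaredError<>, GaussianInitialization> targetNetwork(learningNetwork);

  // Hyper-parameters; UpdateInterval() is n, the length of the n-step return.
  TrainingConfig config;
  config.StepSize() = 0.01;
  config.Discount() = 0.99;
  config.UpdateInterval() = 5;
  config.TargetNetworkSyncInterval() = 100;
  config.StepLimit() = 200;
  config.GradientLimit() = 40;

  NStepQLearningWorker<CartPole, decltype(learningNetwork), ens::VanillaUpdate,
      decltype(policy)> worker(ens::VanillaUpdate(), environment, config,
      /* deterministic */ false);
  worker.Initialize(learningNetwork);

  size_t totalSteps = 0;
  double episodeReturn = 0.0;
  while (totalSteps < 10000)
  {
    // Step() returns true when an episode finishes and fills episodeReturn.
    if (worker.Step(learningNetwork, targetNetwork, totalSteps, policy,
        episodeReturn))
    {
      std::cout << "episode return: " << episodeReturn << std::endl;
    }
  }
}

The same worker, constructed with deterministic set to true, skips the gradient updates in Step() and can be used to evaluate the current learning network.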