/**
 * @file dueling_dqn.hpp
 *
 * Implementation of the Dueling Deep Q-Network (DuelingDQN).
 */
#ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP
#define MLPACK_METHODS_RL_DUELING_DQN_HPP

#include <mlpack/prereqs.hpp>
#include <mlpack/methods/ann/ffn.hpp>
#include <mlpack/methods/ann/init_rules/gaussian_init.hpp>
#include <mlpack/methods/ann/layer/layer.hpp>
#include <mlpack/methods/ann/loss_functions/empty_loss.hpp>
#include <mlpack/methods/ann/loss_functions/mean_squared_error.hpp>
namespace mlpack {
namespace rl {

/**
 * Implementation of the Dueling Deep Q-Learning network. The dueling
 * architecture splits the estimator into a state-value stream V(s) and an
 * action-advantage stream A(s, a), sharing a common feature network, and
 * combines them as Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
 *
 * For more information, see the following paper.
 *
 * @code
 * @article{wang2015dueling,
 *   author = {Ziyu Wang and Tom Schaul and Matteo Hessel and Hado van
 *             Hasselt and Marc Lanctot and Nando de Freitas},
 *   title  = {Dueling Network Architectures for Deep Reinforcement Learning},
 *   year   = {2015},
 *   url    = {https://arxiv.org/abs/1511.06581}
 * }
 * @endcode
 *
 * @tparam OutputLayerType The output layer type of the network.
 * @tparam InitType The initialization type used for the network.
 * @tparam CompleteNetworkType The type of the assembled dueling network.
 * @tparam FeatureNetworkType The type of the shared feature network.
 * @tparam AdvantageNetworkType The type of the advantage stream.
 * @tparam ValueNetworkType The type of the value stream.
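 *
 * A minimal usage sketch (the state dimensionality, layer sizes, and action
 * count below are illustrative assumptions, not part of this class):
 *
 * @code
 * // Dueling network with hidden sizes 64 and 32 and 2 actions (non-noisy).
 * DuelingDQN<> network(64, 32, 2);
 * network.Reset(4); // Initialize the weights for 4-dimensional states.
 *
 * arma::mat state(4, 1, arma::fill::randu);
 * arma::mat actionValue;
 * network.Predict(state, actionValue); // actionValue is 2 x 1.
 * @endcode
 */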
template <
  typename OutputLayerType = ann::EmptyLoss,
  typename InitType = ann::GaussianInitialization,
  typename CompleteNetworkType = ann::FFN<OutputLayerType, InitType>,
  typename FeatureNetworkType = ann::MultiLayer<arma::mat>,
  typename AdvantageNetworkType = ann::MultiLayer<arma::mat>,
  typename ValueNetworkType = ann::MultiLayer<arma::mat>
>
class DuelingDQN
{
 public:
  //! Default constructor.
  DuelingDQN() : isNoisy(false)
  {
    // TODO: this really ought to use a DAG network, but that's not implemented
    // yet.
    featureNetwork = new ann::MultiLayer<arma::mat>();
    valueNetwork = new ann::MultiLayer<arma::mat>();
    advantageNetwork = new ann::MultiLayer<arma::mat>();
    concat = new ann::Concat();

    concat->Add(valueNetwork);
    concat->Add(advantageNetwork);
    completeNetwork.Add(featureNetwork);
    completeNetwork.Add(concat);
  }

  /**
   * Construct an instance of DuelingDQN class.
   *
   * @param h1 Number of neurons in the first hidden layer.
   * @param h2 Number of neurons in the second hidden layer.
   * @param outputDim Number of output neurons (one per action).
   * @param isNoisy Specifies whether the network should be of type noisy.
   * @param init Specifies the initialization rule for the network.
   * @param outputLayer Specifies the output layer type for the network.
   */
  DuelingDQN(const int h1,
             const int h2,
             const int outputDim,
             const bool isNoisy = false,
             InitType init = InitType(),
             OutputLayerType outputLayer = OutputLayerType()) :
      completeNetwork(outputLayer, init),
      isNoisy(isNoisy)
  {
    featureNetwork = new ann::MultiLayer<arma::mat>();
    featureNetwork->Add(new ann::Linear(h1));
    featureNetwork->Add(new ann::ReLU());

    valueNetwork = new ann::MultiLayer<arma::mat>();
    advantageNetwork = new ann::MultiLayer<arma::mat>();

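    // When isNoisy is set, NoisyLinear layers are used instead of Linear
    // layers, and their positions are recorded in noisyLayerIndex so that
    // ResetNoise() can later re-sample their noise parameters.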
    if (isNoisy)
    {
      noisyLayerIndex.push_back(valueNetwork->Network().size());
      valueNetwork->Add(new ann::NoisyLinear(h2));
      advantageNetwork->Add(new ann::NoisyLinear(h2));

      valueNetwork->Add(new ann::ReLU());
      advantageNetwork->Add(new ann::ReLU());

      noisyLayerIndex.push_back(valueNetwork->Network().size());
      valueNetwork->Add(new ann::NoisyLinear(1));
      advantageNetwork->Add(new ann::NoisyLinear(outputDim));
    }
    else
    {
      valueNetwork->Add(new ann::Linear(h2));
      valueNetwork->Add(new ann::ReLU());
      valueNetwork->Add(new ann::Linear(1));

      advantageNetwork->Add(new ann::Linear(h2));
      advantageNetwork->Add(new ann::ReLU());
      advantageNetwork->Add(new ann::Linear(outputDim));
    }

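    // Concatenate the two streams: row 0 of the concatenated output is the
    // scalar value V(s), and the remaining outputDim rows are the advantages
    // A(s, a).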
    concat = new ann::Concat();
    concat->Add(valueNetwork);
    concat->Add(advantageNetwork);

    completeNetwork.Add(featureNetwork);
    completeNetwork.Add(concat);
  }

  /**
   * Construct an instance of DuelingDQN class from a pre-constructed network.
   *
   * @param featureNetwork The pre-constructed feature network.
   * @param advantageNetwork The pre-constructed advantage network.
   * @param valueNetwork The pre-constructed value network.
   * @param isNoisy Specifies whether the network should be of type noisy.
   */
  DuelingDQN(FeatureNetworkType& featureNetwork,
             AdvantageNetworkType& advantageNetwork,
             ValueNetworkType& valueNetwork,
             const bool isNoisy = false) :
      // Store the addresses of the given networks, since the members are
      // pointers and cannot be initialized directly from references.
      featureNetwork(&featureNetwork),
      advantageNetwork(&advantageNetwork),
      valueNetwork(&valueNetwork),
      isNoisy(isNoisy)
  {
    concat = new ann::Concat();
    concat->Add(this->valueNetwork);
    concat->Add(this->advantageNetwork);
    completeNetwork.Add(this->featureNetwork);
    completeNetwork.Add(concat);
  }

  //! Copy constructor.
  DuelingDQN(const DuelingDQN& /* model */) : isNoisy(false)
  { /* Nothing to do here. */ }

  //! Copy assignment operator.
  void operator=(const DuelingDQN& model)
  {
    *valueNetwork = *model.valueNetwork;
    *advantageNetwork = *model.advantageNetwork;
    *featureNetwork = *model.featureNetwork;
    isNoisy = model.isNoisy;
    noisyLayerIndex = model.noisyLayerIndex;
  }

  /**
   * Predict the responses to a given set of predictors.
   *
   * The responses will reflect the output of the given output layer as
   * presented with the given set of predictors.
   *
   * @param state The input state.
   * @param actionValue Matrix to put output action values of states input.
   */
  void Predict(const arma::mat state, arma::mat& actionValue)
  {
    arma::mat advantage, value, networkOutput;
    completeNetwork.Predict(state, networkOutput);
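    // Row 0 of the network output is the value stream V(s); the remaining
    // rows are the advantage stream A(s, a). Combine them as
    // Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); subtracting the mean
    // advantage keeps the two streams identifiable.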
    value = networkOutput.row(0);
    advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
    actionValue = advantage.each_row() + (value - arma::mean(advantage));
  }

  /**
   * Perform the forward pass of the states in real batch mode.
   *
   * @param state The input state.
   * @param actionValue Matrix to put output action values of states input.
   */
  void Forward(const arma::mat state, arma::mat& actionValue)
  {
    arma::mat advantage, value, networkOutput;
    completeNetwork.Forward(state, networkOutput);
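    // Apply the same dueling aggregation as in Predict():
    // Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).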
    value = networkOutput.row(0);
    advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
    actionValue = advantage.each_row() + (value - arma::mean(advantage));
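    // Cache the aggregated action values; Backward() uses them to compute
    // the loss gradient.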
    this->actionValues = actionValue;
  }

  /**
   * Perform the backward pass of the state in real batch mode.
   *
   * @param state The input state.
   * @param target The training target.
   * @param gradient The gradient to be stored.
   */
  void Backward(const arma::mat state, arma::mat& target, arma::mat& gradient)
  {
    arma::mat gradLoss;
    lossFunction.Backward(this->actionValues, target, gradLoss);
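
    // Backpropagate through the dueling aggregation Q = V + (A - mean(A)):
    // the gradient w.r.t. the value stream is the sum of the per-action
    // gradients, and the gradient w.r.t. the advantage stream is the
    // per-action gradient minus its mean.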
217 
218  arma::mat gradValue = arma::sum(gradLoss);
219  arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
220 
221  arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
222  completeNetwork.Backward(state, grad, gradient);
223  }

  /**
   * Resets the parameters of the network.
   */
  void Reset(const size_t inputDimensionality = 0)
  {
    completeNetwork.Reset(inputDimensionality);
  }

  /**
   * Resets noise of the network, if the network is of type noisy.
   */
  void ResetNoise()
  {
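    // The value and advantage streams are built with mirrored layer layouts,
    // so the indices stored in noisyLayerIndex are valid for both networks.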
    for (size_t i = 0; i < noisyLayerIndex.size(); i++)
    {
      dynamic_cast<ann::NoisyLinear*>(
          valueNetwork->Network()[noisyLayerIndex[i]])->ResetNoise();
      dynamic_cast<ann::NoisyLinear*>(
          advantageNetwork->Network()[noisyLayerIndex[i]])->ResetNoise();
    }
  }

  //! Return the Parameters.
  const arma::mat& Parameters() const { return completeNetwork.Parameters(); }
  //! Modify the Parameters.
  arma::mat& Parameters() { return completeNetwork.Parameters(); }

 private:
  //! Locally-stored complete network.
  CompleteNetworkType completeNetwork;

  //! Locally-stored concat layer joining the value and advantage streams.
  ann::Concat* concat;

  //! Locally-stored feature network.
  FeatureNetworkType* featureNetwork;

  //! Locally-stored advantage network.
  AdvantageNetworkType* advantageNetwork;

  //! Locally-stored value network.
  ValueNetworkType* valueNetwork;

  //! Indicates whether the network uses noisy layers.
  bool isNoisy;

  //! Indices of the noisy layers in the value and advantage streams.
  std::vector<size_t> noisyLayerIndex;

  //! Cached action values from the last Forward() pass.
  arma::mat actionValues;

  //! Loss function used by Backward() to compute the gradient.
  ann::MeanSquaredError lossFunction;
};

} // namespace rl
} // namespace mlpack

#endif