dueling_dqn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP
13 #define MLPACK_METHODS_RL_DUELING_DQN_HPP
14 
15 #include <mlpack/prereqs.hpp>
21 
22 namespace mlpack {
23 namespace rl {
24 
25 using namespace mlpack::ann;
26 
48 template <
49  typename OutputLayerType = EmptyLoss<>,
50  typename InitType = GaussianInitialization,
51  typename CompleteNetworkType = FFN<OutputLayerType, InitType>,
52  typename FeatureNetworkType = Sequential<>,
53  typename AdvantageNetworkType = Sequential<>,
54  typename ValueNetworkType = Sequential<>
55 >
57 {
58  public:
60  DuelingDQN() : isNoisy(false)
61  {
62  featureNetwork = new Sequential<>();
63  valueNetwork = new Sequential<>();
64  advantageNetwork = new Sequential<>();
65  concat = new Concat<>(true);
66 
67  concat->Add(valueNetwork);
68  concat->Add(advantageNetwork);
69  completeNetwork.Add(new IdentityLayer<>());
70  completeNetwork.Add(featureNetwork);
71  completeNetwork.Add(concat);
72  }
73 
85  DuelingDQN(const int inputDim,
86  const int h1,
87  const int h2,
88  const int outputDim,
89  const bool isNoisy = false,
90  InitType init = InitType(),
91  OutputLayerType outputLayer = OutputLayerType()):
92  completeNetwork(outputLayer, init),
93  isNoisy(isNoisy)
94  {
95  featureNetwork = new Sequential<>();
96  featureNetwork->Add(new Linear<>(inputDim, h1));
97  featureNetwork->Add(new ReLULayer<>());
98 
99  valueNetwork = new Sequential<>();
100  advantageNetwork = new Sequential<>();
101 
102  if (isNoisy)
103  {
104  noisyLayerIndex.push_back(valueNetwork->Model().size());
105  valueNetwork->Add(new NoisyLinear<>(h1, h2));
106  advantageNetwork->Add(new NoisyLinear<>(h1, h2));
107 
108  valueNetwork->Add(new ReLULayer<>());
109  advantageNetwork->Add(new ReLULayer<>());
110 
111  noisyLayerIndex.push_back(valueNetwork->Model().size());
112  valueNetwork->Add(new NoisyLinear<>(h2, 1));
113  advantageNetwork->Add(new NoisyLinear<>(h2, outputDim));
114  }
115  else
116  {
117  valueNetwork->Add(new Linear<>(h1, h2));
118  valueNetwork->Add(new ReLULayer<>());
119  valueNetwork->Add(new Linear<>(h2, 1));
120 
121  advantageNetwork->Add(new Linear<>(h1, h2));
122  advantageNetwork->Add(new ReLULayer<>());
123  advantageNetwork->Add(new Linear<>(h2, outputDim));
124  }
125 
126  concat = new Concat<>(true);
127  concat->Add(valueNetwork);
128  concat->Add(advantageNetwork);
129 
130  completeNetwork.Add(new IdentityLayer<>());
131  completeNetwork.Add(featureNetwork);
132  completeNetwork.Add(concat);
133  this->ResetParameters();
134  }
135 
144  DuelingDQN(FeatureNetworkType& featureNetwork,
145  AdvantageNetworkType& advantageNetwork,
146  ValueNetworkType& valueNetwork,
147  const bool isNoisy = false):
148  featureNetwork(featureNetwork),
149  advantageNetwork(advantageNetwork),
150  valueNetwork(valueNetwork),
151  isNoisy(isNoisy)
152  {
153  concat = new Concat<>(true);
154  concat->Add(valueNetwork);
155  concat->Add(advantageNetwork);
156  completeNetwork.Add(new IdentityLayer<>());
157  completeNetwork.Add(featureNetwork);
158  completeNetwork.Add(concat);
159  this->ResetParameters();
160  }
161 
163  DuelingDQN(const DuelingDQN& /* model */) : isNoisy(false)
164  { /* Nothing to do here. */ }
165 
167  void operator = (const DuelingDQN& model)
168  {
169  *valueNetwork = *model.valueNetwork;
170  *advantageNetwork = *model.advantageNetwork;
171  *featureNetwork = *model.featureNetwork;
172  isNoisy = model.isNoisy;
173  noisyLayerIndex = model.noisyLayerIndex;
174  }
175 
187  void Predict(const arma::mat state, arma::mat& actionValue)
188  {
189  arma::mat advantage, value, networkOutput;
190  completeNetwork.Predict(state, networkOutput);
191  value = networkOutput.row(0);
192  advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
193  actionValue = advantage.each_row() +
194  (value - arma::mean(advantage));
195  }
196 
203  void Forward(const arma::mat state, arma::mat& actionValue)
204  {
205  arma::mat advantage, value, networkOutput;
206  completeNetwork.Forward(state, networkOutput);
207  value = networkOutput.row(0);
208  advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
209  actionValue = advantage.each_row() +
210  (value - arma::mean(advantage));
211  this->actionValues = actionValue;
212  }
213 
221  void Backward(const arma::mat state, arma::mat& target, arma::mat& gradient)
222  {
223  arma::mat gradLoss;
224  lossFunction.Backward(this->actionValues, target, gradLoss);
225 
226  arma::mat gradValue = arma::sum(gradLoss);
227  arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
228 
229  arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
230  completeNetwork.Backward(state, grad, gradient);
231  }
232 
237  {
238  completeNetwork.ResetParameters();
239  }
240 
244  void ResetNoise()
245  {
246  for (size_t i = 0; i < noisyLayerIndex.size(); i++)
247  {
248  boost::get<NoisyLinear<>*>
249  (valueNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
250  boost::get<NoisyLinear<>*>
251  (advantageNetwork->Model()[noisyLayerIndex[i]])->ResetNoise();
252  }
253  }
254 
256  const arma::mat& Parameters() const { return completeNetwork.Parameters(); }
258  arma::mat& Parameters() { return completeNetwork.Parameters(); }
259 
260  private:
262  CompleteNetworkType completeNetwork;
263 
265  Concat<>* concat;
266 
268  FeatureNetworkType* featureNetwork;
269 
271  AdvantageNetworkType* advantageNetwork;
272 
274  ValueNetworkType* valueNetwork;
275 
277  bool isNoisy;
278 
280  std::vector<size_t> noisyLayerIndex;
281 
283  arma::mat actionValues;
284 
286  MeanSquaredError<> lossFunction;
287 };
288 
289 } // namespace rl
290 } // namespace mlpack
291 
292 #endif
Artificial Neural Network.
DuelingDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of DuelingDQN class.
Definition: dueling_dqn.hpp:85
Linear algebra utility functions, generally performed on matrices or vectors.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
void Forward(const arma::mat state, arma::mat &actionValue)
Perform the forward pass of the states in real batch mode.
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
Definition: layer_types.hpp:89
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Implementation of the Dueling Deep Q-Learning network.
Definition: dueling_dqn.hpp:56
The empty loss does nothing, letting the user calculate the loss outside the model.
Definition: empty_loss.hpp:35
Implementation of the base layer.
Definition: base_layer.hpp:69
DuelingDQN()
Default constructor.
Definition: dueling_dqn.hpp:60
Implementation of the Concat class.
Definition: concat.hpp:45
Implementation of the NoisyLinear layer class.
DuelingDQN(const DuelingDQN &)
Copy constructor.
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
DuelingDQN(FeatureNetworkType &featureNetwork, AdvantageNetworkType &advantageNetwork, ValueNetworkType &valueNetwork, const bool isNoisy=false)
Construct an instance of DuelingDQN class from a pre-constructed network.
The mean squared error performance function measures the network's performance according to the mean squared error between the predicted and target outputs.
void ResetParameters()
Resets the parameters of the network.
arma::mat & Parameters()
Modify the Parameters.
Implementation of a standard feed forward network.
Definition: ffn.hpp:52
const arma::mat & Parameters() const
Return the Parameters.
void Add(Args... args)
Definition: sequential.hpp:143
Implementation of the Sequential class.
void Add(Args... args)
Definition: concat.hpp:147
This class is used to initialize the weight matrix with a Gaussian distribution.