12 #ifndef MLPACK_METHODS_RL_DUELING_DQN_HPP 13 #define MLPACK_METHODS_RL_DUELING_DQN_HPP 48 typename InitType = ann::GaussianInitialization,
49 typename CompleteNetworkType = ann::FFN<OutputLayerType, InitType>,
50 typename FeatureNetworkType = ann::MultiLayer<arma::mat>,
51 typename AdvantageNetworkType = ann::MultiLayer<arma::mat>,
52 typename ValueNetworkType = ann::MultiLayer<arma::mat>
67 concat->
Add(valueNetwork);
68 concat->
Add(advantageNetwork);
69 completeNetwork.Add(featureNetwork);
70 completeNetwork.Add(concat);
86 const bool isNoisy =
false,
87 InitType init = InitType(),
88 OutputLayerType outputLayer = OutputLayerType()):
89 completeNetwork(outputLayer, init),
101 noisyLayerIndex.push_back(valueNetwork->Network().size());
108 noisyLayerIndex.push_back(valueNetwork->Network().size());
124 concat->
Add(valueNetwork);
125 concat->
Add(advantageNetwork);
127 completeNetwork.Add(featureNetwork);
128 completeNetwork.Add(concat);
140 AdvantageNetworkType& advantageNetwork,
141 ValueNetworkType& valueNetwork,
142 const bool isNoisy =
false):
143 featureNetwork(featureNetwork),
144 advantageNetwork(advantageNetwork),
145 valueNetwork(valueNetwork),
149 concat->
Add(valueNetwork);
150 concat->
Add(advantageNetwork);
151 completeNetwork.Add(featureNetwork);
152 completeNetwork.Add(concat);
162 *valueNetwork = *model.valueNetwork;
163 *advantageNetwork = *model.advantageNetwork;
164 *featureNetwork = *model.featureNetwork;
165 isNoisy = model.isNoisy;
166 noisyLayerIndex = model.noisyLayerIndex;
180 void Predict(
const arma::mat state, arma::mat& actionValue)
182 arma::mat advantage, value, networkOutput;
183 completeNetwork.Predict(state, networkOutput);
184 value = networkOutput.row(0);
185 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
186 actionValue = advantage.each_row() + (value - arma::mean(advantage));
195 void Forward(
const arma::mat state, arma::mat& actionValue)
197 arma::mat advantage, value, networkOutput;
198 completeNetwork.Forward(state, networkOutput);
199 value = networkOutput.row(0);
200 advantage = networkOutput.rows(1, networkOutput.n_rows - 1);
201 actionValue = advantage.each_row() +
202 (value - arma::mean(advantage));
203 this->actionValues = actionValue;
213 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
216 lossFunction.
Backward(this->actionValues, target, gradLoss);
218 arma::mat gradValue = arma::sum(gradLoss);
219 arma::mat gradAdvantage = gradLoss.each_row() - arma::mean(gradLoss);
221 arma::mat grad = arma::join_cols(gradValue, gradAdvantage);
222 completeNetwork.Backward(state, grad, gradient);
/**
 * Resets the parameters of the network.
 *
 * @param inputDimensionality Input dimensionality forwarded directly to the
 *     underlying complete network's Reset() (default 0 — presumably "keep /
 *     infer the current input size"; confirm against the FFN API).
 */
void Reset(const size_t inputDimensionality = 0)
{
  completeNetwork.Reset(inputDimensionality);
}
/**
 * Resets noise of the network, if the network is of type noisy.
 *
 * Walks the layer indices recorded at construction time (noisyLayerIndex)
 * and re-samples the noise parameters (epsilons) of the corresponding
 * layer in both the value and the advantage head.
 */
void ResetNoise()
{
  for (size_t i = 0; i < noisyLayerIndex.size(); i++)
  {
    // NOTE(review): the cast target was elided in the extracted text;
    // NoisyLinearType<arma::mat> matches the NoisyLinear layer referenced
    // elsewhere in this header — confirm against the original source.
    dynamic_cast<ann::NoisyLinearType<arma::mat>*>(
        (valueNetwork->Network()[noisyLayerIndex[i]]))->ResetNoise();
    dynamic_cast<ann::NoisyLinearType<arma::mat>*>(
        (advantageNetwork->Network()[noisyLayerIndex[i]]))->ResetNoise();
  }
}
//! Return the Parameters (read-only view of the complete network's weights).
const arma::mat& Parameters() const { return completeNetwork.Parameters(); }
//! Modify the Parameters (mutable reference into the complete network).
arma::mat& Parameters() { return completeNetwork.Parameters(); }
254 CompleteNetworkType completeNetwork;
260 FeatureNetworkType* featureNetwork;
263 AdvantageNetworkType* advantageNetwork;
266 ValueNetworkType* valueNetwork;
272 std::vector<size_t> noisyLayerIndex;
275 arma::mat actionValues;
The mean squared error performance function measures the network's performance according to the mean of squared errors.
EmptyLossType< arma::mat > EmptyLoss
Implementation of the NoisyLinear layer class.
Linear algebra utility functions, generally performed on matrices or vectors.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
void Forward(const arma::mat state, arma::mat &actionValue)
Perform the forward pass of the states in real batch mode.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
ConcatType< arma::mat > Concat
Implementation of the Dueling Deep Q-Learning network.
void ResetNoise()
Reset the noise parameters (epsilons).
Implementation of the Linear layer class.
Implementation of the base layer.
DuelingDQN()
Default constructor.
void operator=(const DuelingDQN &model)
Copy assignment operator.
A "multi-layer" is a layer that is a wrapper around other layers.
DuelingDQN(const DuelingDQN &)
Copy constructor.
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
DuelingDQN(const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of DuelingDQN class.
DuelingDQN(FeatureNetworkType &featureNetwork, AdvantageNetworkType &advantageNetwork, ValueNetworkType &valueNetwork, const bool isNoisy=false)
Construct an instance of DuelingDQN class from a pre-constructed network.
void Backward(const MatType &prediction, const MatType &target, MatType &loss)
Ordinary feed backward pass of a neural network.
void Reset(const size_t inputDimensionality=0)
Resets the parameters of the network.
arma::mat & Parameters()
Modify the Parameters.
const arma::mat & Parameters() const
Return the Parameters.
Implementation of the Concat class.