simple_dqn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP
13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP
14 
15 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace rl {
23 
24 using namespace mlpack::ann;
25 
31 template<
32  typename OutputLayerType = MeanSquaredError<>,
33  typename InitType = GaussianInitialization,
34  typename NetworkType = FFN<OutputLayerType, InitType>
35 >
36 class SimpleDQN
37 {
38  public:
42  SimpleDQN() : network(), isNoisy(false)
43  { /* Nothing to do here. */ }
44 
56  SimpleDQN(const int inputDim,
57  const int h1,
58  const int h2,
59  const int outputDim,
60  const bool isNoisy = false,
61  InitType init = InitType(),
62  OutputLayerType outputLayer = OutputLayerType()):
63  network(outputLayer, init),
64  isNoisy(isNoisy)
65  {
66  network.Add(new Linear<>(inputDim, h1));
67  network.Add(new ReLULayer<>());
68  if (isNoisy)
69  {
70  noisyLayerIndex.push_back(network.Model().size());
71  network.Add(new NoisyLinear<>(h1, h2));
72  network.Add(new ReLULayer<>());
73  noisyLayerIndex.push_back(network.Model().size());
74  network.Add(new NoisyLinear<>(h2, outputDim));
75  }
76  else
77  {
78  network.Add(new Linear<>(h1, h2));
79  network.Add(new ReLULayer<>());
80  network.Add(new Linear<>(h2, outputDim));
81  }
82  }
83 
90  SimpleDQN(NetworkType& network, const bool isNoisy = false):
91  network(network),
92  isNoisy(isNoisy)
93  { /* Nothing to do here. */ }
94 
106  void Predict(const arma::mat state, arma::mat& actionValue)
107  {
108  network.Predict(state, actionValue);
109  }
110 
117  void Forward(const arma::mat state, arma::mat& target)
118  {
119  network.Forward(state, target);
120  }
121 
126  {
127  network.ResetParameters();
128  }
129 
133  void ResetNoise()
134  {
135  for (size_t i = 0; i < noisyLayerIndex.size(); i++)
136  {
137  boost::get<NoisyLinear<>*>
138  (network.Model()[noisyLayerIndex[i]])->ResetNoise();
139  }
140  }
141 
143  const arma::mat& Parameters() const { return network.Parameters(); }
145  arma::mat& Parameters() { return network.Parameters(); }
146 
154  void Backward(const arma::mat state, arma::mat& target, arma::mat& gradient)
155  {
156  network.Backward(state, target, gradient);
157  }
158 
159  private:
161  NetworkType network;
162 
164  bool isNoisy;
165 
167  std::vector<size_t> noisyLayerIndex;
168 };
169 
170 } // namespace rl
171 } // namespace mlpack
172 
173 #endif
Artificial Neural Network.
void ResetParameters()
Resets the parameters of the network.
Definition: simple_dqn.hpp:125
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
Definition: simple_dqn.hpp:133
SimpleDQN(const int inputDim, const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of SimpleDQN class.
Definition: simple_dqn.hpp:56
Linear algebra utility functions, generally performed on matrices or vectors.
SimpleDQN()
Default constructor.
Definition: simple_dqn.hpp:42
arma::mat & Parameters()
Modify the Parameters.
Definition: simple_dqn.hpp:145
The core includes that mlpack expects; standard C++ includes and Armadillo.
Implementation of the Linear layer class.
Definition: layer_types.hpp:89
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
Definition: simple_dqn.hpp:117
const arma::mat & Parameters() const
Return the Parameters.
Definition: simple_dqn.hpp:143
Implementation of the base layer.
Definition: base_layer.hpp:69
SimpleDQN(NetworkType &network, const bool isNoisy=false)
Construct an instance of SimpleDQN class from a pre-constructed network.
Definition: simple_dqn.hpp:90
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Definition: simple_dqn.hpp:106
Implementation of the NoisyLinear layer class.
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
Definition: simple_dqn.hpp:154
The mean squared error performance function measures the network's performance according to the mean of squared errors.
Implementation of a standard feed forward network.
Definition: ffn.hpp:52
This class is used to initialize the weight matrix with a Gaussian distribution.