simple_dqn.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP
13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP
14 
15 #include <mlpack/prereqs.hpp>
20 
21 namespace mlpack {
22 namespace rl {
23 
29 template<
30  typename OutputLayerType = ann::MeanSquaredError,
31  typename InitType = ann::GaussianInitialization,
32  typename NetworkType = ann::FFN<OutputLayerType, InitType>
33 >
34 class SimpleDQN
35 {
36  public:
40  SimpleDQN() : network(), isNoisy(false)
41  { /* Nothing to do here. */ }
42 
53  SimpleDQN(const int h1,
54  const int h2,
55  const int outputDim,
56  const bool isNoisy = false,
57  InitType init = InitType(),
58  OutputLayerType outputLayer = OutputLayerType()):
59  network(outputLayer, init),
60  isNoisy(isNoisy)
61  {
62  network.Add(new ann::Linear(h1));
63  network.Add(new ann::ReLU());
64  if (isNoisy)
65  {
66  noisyLayerIndex.push_back(network.Network().size());
67  network.Add(new ann::NoisyLinear(h2));
68  network.Add(new ann::ReLU());
69  noisyLayerIndex.push_back(network.Network().size());
70  network.Add(new ann::NoisyLinear(outputDim));
71  }
72  else
73  {
74  network.Add(new ann::Linear(h2));
75  network.Add(new ann::ReLU());
76  network.Add(new ann::Linear(outputDim));
77  }
78  }
79 
86  SimpleDQN(NetworkType& network, const bool isNoisy = false):
87  network(network),
88  isNoisy(isNoisy)
89  { /* Nothing to do here. */ }
90 
102  void Predict(const arma::mat state, arma::mat& actionValue)
103  {
104  network.Predict(state, actionValue);
105  }
106 
113  void Forward(const arma::mat state, arma::mat& target)
114  {
115  network.Forward(state, target);
116  }
117 
121  void Reset(const size_t inputDimensionality = 0)
122  {
123  network.Reset(inputDimensionality);
124  }
125 
129  void ResetNoise()
130  {
131  for (size_t i = 0; i < noisyLayerIndex.size(); i++)
132  {
133  dynamic_cast<ann::NoisyLinear*>(
134  network.Network()[noisyLayerIndex[i]])->ResetNoise();
135  }
136  }
137 
139  const arma::mat& Parameters() const { return network.Parameters(); }
141  arma::mat& Parameters() { return network.Parameters(); }
142 
150  void Backward(const arma::mat state, arma::mat& target, arma::mat& gradient)
151  {
152  network.Backward(state, target, gradient);
153  }
154 
155  private:
157  NetworkType network;
158 
160  bool isNoisy;
161 
163  std::vector<size_t> noisyLayerIndex;
164 };
165 
166 } // namespace rl
167 } // namespace mlpack
168 
169 #endif
SimpleDQN(const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of SimpleDQN class.
Definition: simple_dqn.hpp:53
Implementation of the NoisyLinear layer class.
Definition: noisylinear.hpp:30
void ResetNoise()
Resets noise of the network, if the network is of type noisy.
Definition: simple_dqn.hpp:129
Linear algebra utility functions, generally performed on matrices or vectors.
SimpleDQN()
Default constructor.
Definition: simple_dqn.hpp:40
void Reset(const size_t inputDimensionality=0)
Resets the parameters of the network.
Definition: simple_dqn.hpp:121
arma::mat & Parameters()
Modify the Parameters.
Definition: simple_dqn.hpp:141
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
Definition: simple_dqn.hpp:113
Implementation of the Linear layer class.
Definition: linear.hpp:42
const arma::mat & Parameters() const
Return the Parameters.
Definition: simple_dqn.hpp:139
Implementation of the base layer.
Definition: base_layer.hpp:66
SimpleDQN(NetworkType &network, const bool isNoisy=false)
Construct an instance of SimpleDQN class from a pre-constructed network.
Definition: simple_dqn.hpp:86
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the responses to a given set of predictors.
Definition: simple_dqn.hpp:102
MeanSquaredErrorType< arma::mat > MeanSquaredError
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the state in real batch mode.
Definition: simple_dqn.hpp:150