12 #ifndef MLPACK_METHODS_RL_SIMPLE_DQN_HPP 13 #define MLPACK_METHODS_RL_SIMPLE_DQN_HPP 31 typename InitType = ann::GaussianInitialization,
32 typename NetworkType = ann::FFN<OutputLayerType, InitType>
56 const bool isNoisy =
false,
57 InitType init = InitType(),
58 OutputLayerType outputLayer = OutputLayerType()):
59 network(outputLayer, init),
66 noisyLayerIndex.push_back(network.Network().size());
69 noisyLayerIndex.push_back(network.Network().size());
86 SimpleDQN(NetworkType& network,
const bool isNoisy =
false):
102 void Predict(
const arma::mat state, arma::mat& actionValue)
104 network.Predict(state, actionValue);
113 void Forward(
const arma::mat state, arma::mat& target)
115 network.Forward(state, target);
121 void Reset(
const size_t inputDimensionality = 0)
123 network.Reset(inputDimensionality);
131 for (
size_t i = 0; i < noisyLayerIndex.size(); i++)
134 network.Network()[noisyLayerIndex[i]])->
ResetNoise();
139 const arma::mat&
Parameters()
const {
return network.Parameters(); }
150 void Backward(
const arma::mat state, arma::mat& target, arma::mat& gradient)
152 network.Backward(state, target, gradient);
163 std::vector<size_t> noisyLayerIndex;
SimpleDQN(const int h1, const int h2, const int outputDim, const bool isNoisy=false, InitType init=InitType(), OutputLayerType outputLayer=OutputLayerType())
Construct an instance of SimpleDQN class.
Implementation of the NoisyLinear layer class.
void ResetNoise()
Resets the noise of every noisy layer recorded in the network (only has an effect if the network was constructed as noisy).
Linear algebra utility functions, generally performed on matrices or vectors.
SimpleDQN()
Default constructor.
void Reset(const size_t inputDimensionality=0)
Resets the parameters of the network.
arma::mat & Parameters()
Modify the parameters.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Forward(const arma::mat state, arma::mat &target)
Perform the forward pass of the states in real batch mode.
Implementation of the Linear layer class.
const arma::mat & Parameters() const
Return the parameters.
Implementation of the base layer.
SimpleDQN(NetworkType &network, const bool isNoisy=false)
Construct an instance of SimpleDQN class from a pre-constructed network.
void Predict(const arma::mat state, arma::mat &actionValue)
Predict the action values for a given batch of states.
MeanSquaredErrorType< arma::mat > MeanSquaredError
void Backward(const arma::mat state, arma::mat &target, arma::mat &gradient)
Perform the backward pass of the states in real batch mode.