13 #ifndef MLPACK_METHODS_ANN_LOSS_FUNCTIONS_VR_CLASS_REWARD_HPP 14 #define MLPACK_METHODS_ANN_LOSS_FUNCTIONS_VR_CLASS_REWARD_HPP 30 template<
typename MatType = arma::mat>
50 typename MatType::elem_type
Forward(
const MatType& input,
51 const MatType& target);
64 void Backward(
const MatType& input,
const MatType& target, MatType& output);
71 template <
typename LayerType,
typename... Args>
72 void Add(Args... args) { network.push_back(
new LayerType(args...)); }
81 network.push_back(layer);
85 const std::vector<Layer<MatType>*>&
Network()
const {
return network; }
87 std::vector<Layer<MatType>*>&
Network() {
return network; }
93 double Scale()
const {
return scale; }
98 template<
typename Archive>
99 void serialize(Archive& ,
const uint32_t );
112 std::vector<Layer<MatType>*> network;
122 #include "vr_class_reward_impl.hpp" void serialize(Archive &, const uint32_t)
Serialize the layer.
VRClassRewardType< arma::mat > VRClassReward
Linear algebra utility functions, generally performed on matrices or vectors.
MatType::elem_type Forward(const MatType &input, const MatType &target)
Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activ...
The core includes that mlpack expects; standard C++ includes and Armadillo.
double Scale() const
Get the value of scale parameter.
bool SizeAverage() const
Get the value of parameter sizeAverage.
VRClassRewardType(const double scale=1, const bool sizeAverage=true)
Create the VRClassRewardType object.
const std::vector< Layer< MatType > * > & Network() const
Get the network.
void Add(Args... args)
Add a new module to the model.
Implementation of the variance reduced classification reinforcement layer.
void Backward(const MatType &input, const MatType &target, MatType &output)
Ordinary feed backward pass of a neural network.
A layer is an abstract class implementing common neural networks operations, such as convolution...
void Add(Layer< MatType > *layer)
Add a new module to the model.
std::vector< Layer< MatType > * > & Network()
Modify the network.