Reparametrization< InputDataType, OutputDataType > Class Template Reference

Implementation of the Reparametrization layer class. More...

Public Member Functions

 Reparametrization ()
 Create the Reparametrization object. More...

 
 Reparametrization (const size_t latentSize, const bool stochastic=true, const bool includeKl=true, const double beta=1)
 Create the Reparametrization layer object using the specified sample vector size. More...

 
template<typename eT>
void Backward (const arma::Mat< eT > &&input, arma::Mat< eT > &&gy, arma::Mat< eT > &&g)
 Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f. More...

 
OutputDataType const & Delta () const
 Get the delta. More...

 
OutputDataType & Delta ()
 Modify the delta. More...

 
template<typename eT>
void Forward (const arma::Mat< eT > &&input, arma::Mat< eT > &&output)
 Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f. More...

 
double Loss ()
 Get the KL divergence with standard normal. More...

 
OutputDataType const & OutputParameter () const
 Get the output parameter. More...

 
OutputDataType & OutputParameter ()
 Modify the output parameter. More...

 
size_t const & OutputSize () const
 Get the output size. More...

 
size_t & OutputSize ()
 Modify the output size. More...

 
template<typename Archive>
void serialize (Archive &ar, const unsigned int)
 Serialize the layer. More...

 

Detailed Description


template<typename InputDataType = arma::mat, typename OutputDataType = arma::mat>

class mlpack::ann::Reparametrization< InputDataType, OutputDataType >

Implementation of the Reparametrization layer class.

This layer samples from the given parameters of a normal distribution.

This class also supports beta-VAE, a state-of-the-art framework for automated discovery of interpretable factorised latent representations from raw image data in a completely unsupervised manner.

For more information, refer to the following paper.

@article{ICLR2017,
  title   = {beta-VAE: Learning basic visual concepts with a constrained
             variational framework},
  author  = {Irina Higgins and Loic Matthey and Arka Pal and Christopher
             Burgess and Xavier Glorot and Matthew Botvinick and Shakir
             Mohamed and Alexander Lerchner},
  journal = {2017 International Conference on Learning Representations (ICLR)},
  year    = {2017}
}
Template Parameters
  InputDataType   Type of the input data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
  OutputDataType  Type of the output data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).

Definition at line 93 of file layer_types.hpp.

Constructor & Destructor Documentation

◆ Reparametrization() [1/2]

Create the Reparametrization object.

◆ Reparametrization() [2/2]

Reparametrization ( const size_t  latentSize,
const bool  stochastic = true,
const bool  includeKl = true,
const double  beta = 1 
)

Create the Reparametrization layer object using the specified sample vector size.

Parameters
  latentSize  The number of output latent units.
  stochastic  Whether to return a random sample or a constant output.
  includeKl   Whether to include the KL divergence term in the backward pass.
  beta        The beta (hyper)parameter for beta-VAE, described above.

Member Function Documentation

◆ Backward()

void Backward ( const arma::Mat< eT > &&  input,
arma::Mat< eT > &&  gy,
arma::Mat< eT > &&  g 
)

Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.

Using the results from the feed forward pass.

Parameters
  input  The propagated input activation.
  gy     The backpropagated error.
  g      The calculated gradient.

◆ Delta() [1/2]

OutputDataType const& Delta ( ) const
inline

Get the delta.

Definition at line 104 of file reparametrization.hpp.

◆ Delta() [2/2]

OutputDataType& Delta ( )
inline

Modify the delta.

Definition at line 106 of file reparametrization.hpp.

◆ Forward()

void Forward ( const arma::Mat< eT > &&  input,
arma::Mat< eT > &&  output 
)

Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.

Parameters
  input   Input data used for evaluating the specified function.
  output  Resulting output activation.

◆ Loss()

double Loss ( )
inline

Get the KL divergence with standard normal.

Definition at line 114 of file reparametrization.hpp.

References Reparametrization< InputDataType, OutputDataType >::serialize().

◆ OutputParameter() [1/2]

OutputDataType const& OutputParameter ( ) const
inline

Get the output parameter.

Definition at line 99 of file reparametrization.hpp.

◆ OutputParameter() [2/2]

OutputDataType& OutputParameter ( )
inline

Modify the output parameter.

Definition at line 101 of file reparametrization.hpp.

◆ OutputSize() [1/2]

size_t const& OutputSize ( ) const
inline

Get the output size.

Definition at line 109 of file reparametrization.hpp.

◆ OutputSize() [2/2]

size_t& OutputSize ( )
inline

Modify the output size.

Definition at line 111 of file reparametrization.hpp.

◆ serialize()

void serialize ( Archive &  ar,
const unsigned  int 
)

The documentation for this class was generated from the following files:
  • /home/jenkins-mlpack/mlpack.org/_src/mlpack-3.2.1/src/mlpack/methods/ann/layer/layer_types.hpp
  • /home/jenkins-mlpack/mlpack.org/_src/mlpack-3.2.1/src/mlpack/methods/ann/layer/reparametrization.hpp