GAN< Model, InitializationRuleType, Noise, PolicyType > Class Template Reference

The implementation of the standard GAN module.

Public Member Functions

GAN (arma::mat &trainData, Model generator, Model discriminator, InitializationRuleType &initializeRule, Noise &noiseFunction, const size_t noiseDim, const size_t batchSize, const size_t generatorUpdateStep, const size_t preTrainSize, const double multiplier, const double clippingParameter=0.01, const double lambda=10.0)
 Constructor for GAN class.

GAN (const GAN &)
 Copy constructor.

GAN (GAN &&)
 Move constructor.

const Model & Discriminator () const
 Return the discriminator of the GAN.

Model & Discriminator ()
 Modify the discriminator of the GAN.

template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
 Evaluate function for the Standard GAN and DCGAN.

template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
 Evaluate function for the WGAN.

template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type Evaluate (const arma::mat &parameters, const size_t i, const size_t batchSize)
 Evaluate function for the WGAN-GP.

template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
 EvaluateWithGradient function for the Standard GAN and DCGAN.

template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, double>::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
 EvaluateWithGradient function for the WGAN.

template<typename GradType, typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type EvaluateWithGradient (const arma::mat &parameters, const size_t i, GradType &gradient, const size_t batchSize)
 EvaluateWithGradient function for the WGAN-GP.

void Forward (arma::mat &&input)
 This function does a forward pass through the GAN network.

const Model & Generator () const
 Return the generator of the GAN.

Model & Generator ()
 Modify the generator of the GAN.

template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, void>::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
 Gradient function for Standard GAN and DCGAN.

template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGAN>::value, void>::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
 Gradient function for WGAN.

template<typename Policy = PolicyType>
std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type Gradient (const arma::mat &parameters, const size_t i, arma::mat &gradient, const size_t batchSize)
 Gradient function for WGAN-GP.

size_t NumFunctions () const
 Return the number of separable functions (the number of predictor points).

const arma::mat & Parameters () const
 Return the parameters of the network.

arma::mat & Parameters ()
 Modify the parameters of the network.

void Predict (arma::mat &&input, arma::mat &output)
 This function predicts the output of the network on the given input.

const arma::mat & Predictors () const
 Get the matrix of data points (predictors).

arma::mat & Predictors ()
 Modify the matrix of data points (predictors).

void Reset ()

const arma::mat & Responses () const
 Get the matrix of responses to the input data points.

arma::mat & Responses ()
 Modify the matrix of responses to the input data points.

template<typename Archive>
void serialize (Archive &ar, const unsigned int)
 Serialize the model.

void Shuffle ()
 Shuffle the order of function visitation.

template<typename OptimizerType>
double Train (OptimizerType &Optimizer)
 Train the GAN using the given optimizer.

Detailed Description


template<typename Model, typename InitializationRuleType, typename Noise, typename PolicyType = StandardGAN>
class mlpack::ann::GAN< Model, InitializationRuleType, Noise, PolicyType >

The implementation of the standard GAN module.

Generative Adversarial Networks (GANs) are a class of machine learning algorithms, typically used for unsupervised learning, in which two neural networks, a generator and a discriminator, compete in a zero-sum game: the generator maps random noise to candidate samples, and the discriminator tries to distinguish those samples from real training data. This technique can generate photographs that look at least superficially authentic to human observers, with many realistic characteristics. GANs have been applied to text-to-image synthesis, medical drug discovery, high-resolution image generation, neural machine translation, and other tasks.
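For reference, the zero-sum game between the two networks is the minimax objective from the Goodfellow et al. paper cited below, where D(x) is the discriminator's estimated probability that x is real, G(z) is the generator's output for noise z, p_data is the data distribution and p_z is the noise distribution:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```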

For more information, see the following paper:

@article{Goodfellow14,
author = {Ian J. Goodfellow and Jean Pouget-Abadie and Mehdi Mirza and
Bing Xu and David Warde-Farley and Sherjil Ozair and
Aaron Courville and Yoshua Bengio},
title = {Generative Adversarial Nets},
year = {2014},
url = {http://arxiv.org/abs/1406.2661},
eprint = {1406.2661},
}
Template Parameters
Model: The class type of the Generator and Discriminator.
InitializationRuleType: Type of the initializer.
Noise: The noise function to use.
PolicyType: The GAN variant to be used (StandardGAN, DCGAN, WGAN or WGANGP).

Definition at line 63 of file gan.hpp.
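Each group of overloads in this class is selected at compile time by PolicyType through std::enable_if. The following is a minimal, self-contained sketch of that dispatch pattern; the empty tag structs, ToyGAN and Variant() are hypothetical stand-ins, not mlpack code.

```cpp
#include <string>
#include <type_traits>

// Hypothetical stand-ins for the policy tags named in this reference
// (mlpack defines its own; these empty structs are for illustration only).
struct StandardGAN {};
struct DCGAN {};
struct WGAN {};
struct WGANGP {};

// Sketch of policy-based overload selection, modelled on the std::enable_if
// signatures listed above. ToyGAN and Variant() are illustrative names only.
template<typename PolicyType = StandardGAN>
class ToyGAN
{
 public:
  // Enabled only when PolicyType is StandardGAN or DCGAN.
  template<typename Policy = PolicyType>
  typename std::enable_if<std::is_same<Policy, StandardGAN>::value ||
                          std::is_same<Policy, DCGAN>::value,
                          std::string>::type
  Variant() const { return "standard-or-dcgan"; }

  // Enabled only when PolicyType is WGAN.
  template<typename Policy = PolicyType>
  typename std::enable_if<std::is_same<Policy, WGAN>::value,
                          std::string>::type
  Variant() const { return "wgan"; }

  // Enabled only when PolicyType is WGANGP.
  template<typename Policy = PolicyType>
  typename std::enable_if<std::is_same<Policy, WGANGP>::value,
                          std::string>::type
  Variant() const { return "wgan-gp"; }
};
```

Because Policy defaults to PolicyType, callers never name it explicitly: ToyGAN<WGAN>().Variant() resolves to the WGAN overload, and the other two drop out of overload resolution through substitution failure rather than causing a compile error.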

Constructor & Destructor Documentation

◆ GAN() [1/3]

GAN ( arma::mat &  trainData,
Model  generator,
Model  discriminator,
InitializationRuleType &  initializeRule,
Noise &  noiseFunction,
const size_t  noiseDim,
const size_t  batchSize,
const size_t  generatorUpdateStep,
const size_t  preTrainSize,
const double  multiplier,
const double  clippingParameter = 0.01,
const double  lambda = 10.0 
)

Constructor for GAN class.

Parameters
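trainData: The real data.
generator: Generator network.
discriminator: Discriminator network.
initializeRule: Initialization rule used to initialize the network parameters.
noiseFunction: Function used to generate the noise fed to the Generator.
noiseDim: Dimension of the noise vector.
batchSize: Batch size to be used for training.
generatorUpdateStep: Number of steps to train the Discriminator before updating the Generator.
preTrainSize: Number of pre-training steps of the Discriminator.
multiplier: Ratio of the learning rate of the Discriminator to that of the Generator.
clippingParameter: Weight range for enforcing the Lipschitz constraint (WGAN).
lambda: Coefficient of the gradient penalty (WGAN-GP).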
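The interplay of preTrainSize and generatorUpdateStep can be pictured as a step schedule. The sketch below encodes one plausible reading of the parameter descriptions above; UpdatesGenerator is a hypothetical helper, and the exact counting mlpack uses internally may differ.

```cpp
#include <cstddef>

// Hypothetical helper: returns true if optimizer step `step` (0-based) should
// also update the Generator. Assumed schedule (not taken from mlpack's code):
// the first `preTrainSize` steps train only the Discriminator; afterwards the
// Generator is updated once every `generatorUpdateStep` Discriminator steps.
bool UpdatesGenerator(std::size_t step,
                      std::size_t preTrainSize,
                      std::size_t generatorUpdateStep)
{
  if (generatorUpdateStep == 0)
    return false;  // Guard against division by zero; treat as "never".
  if (step < preTrainSize)
    return false;  // Discriminator pre-training phase.
  // Number of steps taken since pre-training ended, counting from 1.
  const std::size_t sinceStart = step - preTrainSize + 1;
  return sinceStart % generatorUpdateStep == 0;
}
```

Under this reading, with preTrainSize = 2 and generatorUpdateStep = 3, steps 0 to 3 train only the Discriminator and step 4 is the first that also updates the Generator.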

◆ GAN() [2/3]

GAN ( const GAN< Model, InitializationRuleType, Noise, PolicyType > &  )

Copy constructor.

◆ GAN() [3/3]

GAN ( GAN< Model, InitializationRuleType, Noise, PolicyType > &&  )

Move constructor.

Member Function Documentation

◆ Discriminator() [1/2]

const Model& Discriminator ( ) const
inline

Return the discriminator of the GAN.

Definition at line 295 of file gan.hpp.

◆ Discriminator() [2/2]

Model& Discriminator ( )
inline

Modify the discriminator of the GAN.

Definition at line 297 of file gan.hpp.

◆ Evaluate() [1/3]

std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the Standard GAN and DCGAN.

This function returns the objective (loss) value of the Standard GAN or DCGAN on the current batch of inputs.

Parameters
parameters: The parameters of the network.
i: Index of the first input point in the current batch.
batchSize: Number of input points in the current batch.
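As an illustration of the kind of quantity the Standard GAN objective measures, here is a sketch of the textbook discriminator loss for one batch, i.e. the negated GAN value function of Goodfellow et al. It works on plain discriminator probabilities; MeanLog and DiscriminatorLoss are hypothetical helpers, not mlpack's Evaluate().

```cpp
#include <cmath>
#include <vector>

// Average of log(p) over a non-empty batch of probabilities.
double MeanLog(const std::vector<double>& probs)
{
  double sum = 0.0;
  for (double p : probs)
    sum += std::log(p);
  return sum / static_cast<double>(probs.size());
}

// Textbook discriminator loss: -(mean log D(x_real) + mean log(1 - D(G(z)))).
// dReal and dFake hold discriminator output probabilities in (0, 1) for the
// real and the generated samples of the batch.
double DiscriminatorLoss(const std::vector<double>& dReal,
                         const std::vector<double>& dFake)
{
  std::vector<double> oneMinusFake;
  for (double p : dFake)
    oneMinusFake.push_back(1.0 - p);
  return -(MeanLog(dReal) + MeanLog(oneMinusFake));
}
```

When the discriminator cannot tell real from generated samples (it outputs 0.5 everywhere), the loss equals 2 log 2; a discriminator that separates them well scores lower.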

◆ Evaluate() [2/3]

std::enable_if<std::is_same<Policy, WGAN>::value, double>::type Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the WGAN.

This function returns the objective (loss) value of the WGAN on the current batch of inputs.

Parameters
parameters: The parameters of the network.
i: Index of the first input point in the current batch.
batchSize: Number of input points in the current batch.
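For comparison with the Standard GAN loss, the usual Wasserstein objectives work on raw (unbounded) critic scores rather than probabilities. A minimal sketch, assuming one score per sample; Mean, CriticLoss and WGANGeneratorLoss are illustrative names, not mlpack API.

```cpp
#include <numeric>
#include <vector>

// Mean of a non-empty batch of critic scores.
double Mean(const std::vector<double>& v)
{
  return std::accumulate(v.begin(), v.end(), 0.0) /
         static_cast<double>(v.size());
}

// The critic minimizes mean(D(fake)) - mean(D(real)), i.e. it maximizes the
// score gap between real and generated samples.
double CriticLoss(const std::vector<double>& dReal,
                  const std::vector<double>& dFake)
{
  return Mean(dFake) - Mean(dReal);
}

// The generator minimizes -mean(D(fake)), pushing its samples' scores up.
double WGANGeneratorLoss(const std::vector<double>& dFake)
{
  return -Mean(dFake);
}
```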

◆ Evaluate() [3/3]

std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type Evaluate ( const arma::mat &  parameters,
const size_t  i,
const size_t  batchSize 
)

Evaluate function for the WGAN-GP.

This function returns the objective (loss) value of the WGAN-GP on the current batch of inputs.

Parameters
parameters: The parameters of the network.
i: Index of the first input point in the current batch.
batchSize: Number of input points in the current batch.
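WGAN-GP adds a gradient penalty, weighted by the constructor's lambda, to the critic loss. The sketch below shows the standard penalty term, assuming the critic's input gradients at interpolated points have already been computed (one gradient vector per sample); GradientPenalty and Interpolate are hypothetical helpers, not mlpack code.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Standard WGAN-GP penalty: lambda * mean((||grad||_2 - 1)^2), where each
// grad is the critic's gradient with respect to an interpolated input.
// The default lambda = 10.0 mirrors the constructor default.
double GradientPenalty(const std::vector<std::vector<double>>& inputGradients,
                       double lambda = 10.0)
{
  double total = 0.0;
  for (const auto& g : inputGradients)
  {
    double sq = 0.0;
    for (double v : g)
      sq += v * v;
    const double dev = std::sqrt(sq) - 1.0;  // deviation from unit norm
    total += dev * dev;
  }
  return lambda * total / static_cast<double>(inputGradients.size());
}

// Interpolated point x_hat = eps * x_real + (1 - eps) * x_fake, eps in [0, 1].
std::vector<double> Interpolate(const std::vector<double>& real,
                                const std::vector<double>& fake, double eps)
{
  std::vector<double> out(real.size());
  for (std::size_t k = 0; k < real.size(); ++k)
    out[k] = eps * real[k] + (1.0 - eps) * fake[k];
  return out;
}
```

The penalty is zero when every input gradient has unit norm, which is how WGAN-GP softly enforces the Lipschitz constraint instead of clipping weights.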

◆ EvaluateWithGradient() [1/3]

std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, double>::type EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the Standard GAN and DCGAN.

This function returns the objective (loss) value of the Standard GAN or DCGAN on the current batch of inputs and stores the corresponding gradient in gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input point in the current batch.
gradient: Variable in which to store the computed gradient.
batchSize: Number of input points in the current batch.
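The (parameters, i, batchSize) arguments follow a separable-function convention: one call covers the data points i, i + 1, ..., i + batchSize - 1. The toy class below illustrates that convention with a one-parameter squared-error objective instead of a GAN loss; ToySeparableFunction and its members are hypothetical, not mlpack code.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Toy separable objective: a scalar parameter w fitted to scalar data by
// squared error. A call covers points i, i + 1, ..., i + batchSize - 1.
class ToySeparableFunction
{
 public:
  explicit ToySeparableFunction(std::vector<double> points)
      : data(std::move(points)) {}

  // One separable function per data point.
  std::size_t NumFunctions() const { return data.size(); }

  // Sum of (w - x_j)^2 over the batch.
  double Evaluate(double w, std::size_t i, std::size_t batchSize) const
  {
    double loss = 0.0;
    for (std::size_t j = i; j < i + batchSize; ++j)
      loss += (w - data[j]) * (w - data[j]);
    return loss;
  }

  // Same loss; also writes d(loss)/dw = sum of 2 (w - x_j) into `gradient`.
  double EvaluateWithGradient(double w, std::size_t i, double& gradient,
                              std::size_t batchSize) const
  {
    gradient = 0.0;
    for (std::size_t j = i; j < i + batchSize; ++j)
      gradient += 2.0 * (w - data[j]);
    return Evaluate(w, i, batchSize);
  }

 private:
  std::vector<double> data;
};
```

Computing the loss and gradient in one pass, as EvaluateWithGradient does, lets an optimizer avoid a second traversal of the batch.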

◆ EvaluateWithGradient() [2/3]

std::enable_if<std::is_same<Policy, WGAN>::value, double>::type EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the WGAN.

This function returns the objective (loss) value of the WGAN on the current batch of inputs and stores the corresponding gradient in gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input point in the current batch.
gradient: Variable in which to store the computed gradient.
batchSize: Number of input points in the current batch.
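The WGAN variant keeps its critic approximately Lipschitz by clipping weights to the range set by clippingParameter. A minimal sketch of standard weight clipping over a flat weight vector; ClipWeights is a hypothetical helper, and where mlpack applies the clip internally is not shown here.

```cpp
#include <algorithm>
#include <vector>

// Clamp every weight to [-clippingParameter, clippingParameter] after a
// critic update. The default 0.01 mirrors the constructor default.
void ClipWeights(std::vector<double>& weights, double clippingParameter = 0.01)
{
  for (double& w : weights)
    w = std::min(std::max(w, -clippingParameter), clippingParameter);
}
```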

◆ EvaluateWithGradient() [3/3]

std::enable_if<std::is_same<Policy, WGANGP>::value, double>::type EvaluateWithGradient ( const arma::mat &  parameters,
const size_t  i,
GradType &  gradient,
const size_t  batchSize 
)

EvaluateWithGradient function for the WGAN-GP.

This function returns the objective (loss) value of the WGAN-GP on the current batch of inputs and stores the corresponding gradient in gradient.

Parameters
parameters: The parameters of the network.
i: Index of the first input point in the current batch.
gradient: Variable in which to store the computed gradient.
batchSize: Number of input points in the current batch.

◆ Forward()

void Forward ( arma::mat &&  input)

This function does a forward pass through the GAN network.

Parameters
input: Sampled noise.
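The noise passed to Forward() has noiseDim rows and one column per sample. The sketch below fills such a matrix in column-major order (Armadillo's layout) from a flat vector, assuming the noise function is a callable that returns one double per call; SampleNoise is a hypothetical helper.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Build a noiseDim x numSamples noise matrix, stored column-major in a flat
// vector, by drawing each entry from noiseFunction.
std::vector<double> SampleNoise(std::size_t noiseDim, std::size_t numSamples,
                                const std::function<double()>& noiseFunction)
{
  std::vector<double> noise(noiseDim * numSamples);
  for (double& v : noise)
    v = noiseFunction();  // one independent draw per entry
  return noise;
}
```

For standard normal noise, noiseFunction could be a lambda wrapping a std::normal_distribution<double> and a std::mt19937 engine.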

◆ Generator() [1/2]

const Model& Generator ( ) const
inline

Return the generator of the GAN.

Definition at line 291 of file gan.hpp.

◆ Generator() [2/2]

Model& Generator ( )
inline

Modify the generator of the GAN.

Definition at line 293 of file gan.hpp.

◆ Gradient() [1/3]

std::enable_if<std::is_same<Policy, StandardGAN>::value || std::is_same<Policy, DCGAN>::value, void>::type Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for Standard GAN and DCGAN.

This function computes the gradient for whichever network is currently being trained (the Generator or the Discriminator).

Parameters
parameters: The present parameters of the network.
i: Index of the first input point in the current batch.
gradient: Variable in which to store the computed gradient.
batchSize: Number of input points in the current batch.

◆ Gradient() [2/3]

std::enable_if<std::is_same<Policy, WGAN>::value, void>::type Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for WGAN.

This function computes the gradient for whichever network is currently being trained (the Generator or the Discriminator).

Parameters
parameters: The present parameters of the network.
i: Index of the first input point in the current batch.
gradient: Variable in which to store the computed gradient.
batchSize: Number of input points in the current batch.

◆ Gradient() [3/3]

std::enable_if<std::is_same<Policy, WGANGP>::value, void>::type Gradient ( const arma::mat &  parameters,
const size_t  i,
arma::mat &  gradient,
const size_t  batchSize 
)

Gradient function for WGAN-GP.

This function computes the gradient for whichever network is currently being trained (the Generator or the Discriminator).

Parameters
parameters: The present parameters of the network.
i: Index of the first input point in the current batch.
gradient: Variable in which to store the computed gradient.
batchSize: Number of input points in the current batch.

◆ NumFunctions()

size_t NumFunctions ( ) const
inline

Return the number of separable functions (the number of predictor points).

Definition at line 300 of file gan.hpp.
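Since NumFunctions() counts predictor points, the number of mini-batches in one pass over the data follows by ceiling division. A small arithmetic sketch; BatchesPerEpoch is a hypothetical helper, and whether a final partial batch is used depends on the optimizer.

```cpp
#include <cstddef>

// Mini-batches needed to visit every point once when numFunctions points are
// processed batchSize at a time; the last batch may be smaller.
std::size_t BatchesPerEpoch(std::size_t numFunctions, std::size_t batchSize)
{
  return (numFunctions + batchSize - 1) / batchSize;  // ceiling division
}
```

For example, 1000 points with batchSize 32 give 32 batches, the last one holding 8 points.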

◆ Parameters() [1/2]

const arma::mat& Parameters ( ) const
inline

Return the parameters of the network.

Definition at line 286 of file gan.hpp.

◆ Parameters() [2/2]

arma::mat& Parameters ( )
inline

Modify the parameters of the network.

Definition at line 288 of file gan.hpp.

◆ Predict()

void Predict ( arma::mat &&  input,
arma::mat &  output 
)

This function predicts the output of the network on the given input.

Parameters
input: The input to the Discriminator network.
output: The output of the Discriminator network.

◆ Predictors() [1/2]

const arma::mat& Predictors ( ) const
inline

Get the matrix of data points (predictors).

Definition at line 308 of file gan.hpp.

◆ Predictors() [2/2]

arma::mat& Predictors ( )
inline

Modify the matrix of data points (predictors).

Definition at line 310 of file gan.hpp.


◆ Reset()

void Reset ( )

◆ Responses() [1/2]

const arma::mat& Responses ( ) const
inline

Get the matrix of responses to the input data points.

Definition at line 303 of file gan.hpp.

◆ Responses() [2/2]

arma::mat& Responses ( )
inline

Modify the matrix of responses to the input data points.

Definition at line 305 of file gan.hpp.

◆ serialize()

void serialize ( Archive &  ar,
const unsigned  int 
)

Serialize the model.

◆ Shuffle()

void Shuffle ( )

Shuffle the order of function visitation.

This may be called by the optimizer.
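Shuffling the order of function visitation amounts to permuting the indices of the data points so that an optimizer sees them in a random order. A minimal sketch of that idea with a seeded engine; ShuffledVisitOrder is a hypothetical helper, and how Shuffle() reorders the GAN's data internally is not shown here.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Return a random permutation of 0, 1, ..., numFunctions - 1.
std::vector<std::size_t> ShuffledVisitOrder(std::size_t numFunctions,
                                            unsigned int seed)
{
  std::vector<std::size_t> order(numFunctions);
  std::iota(order.begin(), order.end(), std::size_t{0});
  std::mt19937 rng(seed);  // seeded for reproducibility
  std::shuffle(order.begin(), order.end(), rng);
  return order;
}
```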

◆ Train()

double Train ( OptimizerType &  Optimizer)

Train the GAN using the given optimizer.

Returns
The final objective of the trained model (NaN or Inf on error).
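Because Train() signals errors through a NaN or Inf objective rather than an exception, a caller can check the returned value. A minimal sketch; TrainingSucceeded is a hypothetical helper.

```cpp
#include <cmath>

// Treat any non-finite final objective (NaN or +/-Inf) as a failed run.
bool TrainingSucceeded(double finalObjective)
{
  return std::isfinite(finalObjective);
}
```

For example, double objective = gan.Train(optimizer); followed by if (!TrainingSucceeded(objective)) { ... }, where gan and optimizer are assumed to be set up elsewhere.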

The documentation for this class was generated from the following file:
  • /home/jenkins-mlpack/mlpack.org/_src/mlpack-3.2.1/src/mlpack/methods/ann/gan/gan.hpp