MultiheadAttentionType< InputType, OutputType, RegularizerType > Class Template Reference

Multihead Attention allows the model to jointly attend to information from different representation subspaces at different positions. More...

Inheritance diagram for MultiheadAttentionType< InputType, OutputType, RegularizerType >:

Public Member Functions

 MultiheadAttentionType ()
 Default constructor. More...

 
 MultiheadAttentionType (const size_t tgtSeqLen, const size_t srcSeqLen, const size_t embedDim, const size_t numHeads, const InputType &attnmask=InputType(), const InputType &keyPaddingMask=InputType())
 Create the MultiheadAttention object using the specified modules. More...

 
OutputType const & AttentionMask () const
 Get the two dimensional Attention Mask. More...

 
OutputType & AttentionMask ()
 Modify the two dimensional Attention Mask. More...

 
void Backward (const InputType &, const OutputType &gy, OutputType &g)
 Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f. More...

 
MultiheadAttentionType * Clone () const
 Clone the MultiheadAttentionType object. More...

 
size_t EmbedDim () const
 Get the embedding dimension. More...

 
size_t & EmbedDim ()
 Modify the embedding dimension. More...

 
void Forward (const InputType &input, OutputType &output)
 Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f. More...

 
void Gradient (const InputType &input, const OutputType &error, OutputType &gradient)
 Calculate the gradient using the output delta and the input activation. More...

 
size_t InputShape () const
 
OutputType const & KeyPaddingMask () const
 Get Key Padding Mask. More...

 
OutputType & KeyPaddingMask ()
 Modify the Key Padding Mask. More...

 
size_t NumHeads () const
 Get the number of attention heads. More...

 
size_t & NumHeads ()
 Modify the number of attention heads. More...

 
const std::vector< size_t > OutputDimensions () const
 
template<typename Archive>
void serialize (Archive &ar, const uint32_t)
 Serialize the layer. More...

 
void SetWeights (typename OutputType::elem_type *weightsPtr)
 Reset the layer parameters. More...

 
size_t SrcSeqLen () const
 Get the source sequence length. More...

 
size_t & SrcSeqLen ()
 Modify the source sequence length. More...

 
size_t TgtSeqLen () const
 Get the target sequence length. More...

 
size_t & TgtSeqLen ()
 Modify the target sequence length. More...

 
size_t WeightSize () const
 Get the size of the weights. More...

 
const size_t WeightSize () const
 Get the total number of trainable weights in the layer. More...

 
- Public Member Functions inherited from Layer< InputType, OutputType >
 Layer ()
 Default constructor. More...

 
 Layer (const Layer &layer)
 Copy constructor. This is not responsible for copying weights! More...

 
 Layer (Layer &&layer)
 Move constructor. This is not responsible for moving weights! More...

 
virtual ~Layer ()
 Default destructor. More...

 
virtual void Backward (const InputType &, const InputType &, InputType &)
 Performs a backpropagation step through the layer, with respect to the given input. More...

 
virtual void ComputeOutputDimensions ()
 Compute the output dimensions. More...

 
virtual void CustomInitialize (InputType &, const size_t)
 Override the weight matrix of the layer. More...

 
virtual void Forward (const InputType &, InputType &)
 Takes an input object, and computes the corresponding output of the layer. More...

 
virtual void Forward (const InputType &, const InputType &)
 Takes an input and output object, and computes the corresponding loss of the layer. More...

 
virtual void Gradient (const InputType &, const InputType &, InputType &)
 Computing the gradient of the layer with respect to its own input. More...

 
const std::vector< size_t > & InputDimensions () const
 Get the input dimensions. More...

 
std::vector< size_t > & InputDimensions ()
 Modify the input dimensions. More...

 
virtual double Loss ()
 Get the layer loss. More...

 
virtual Layer & operator= (const Layer &layer)
 Copy assignment operator. This is not responsible for copying weights! More...

 
virtual Layer & operator= (Layer &&layer)
 Move assignment operator. This is not responsible for moving weights! More...

 
const std::vector< size_t > & OutputDimensions ()
 Get the output dimensions. More...

 
virtual size_t OutputSize () final
 Get the number of elements in the output from this layer. More...

 
virtual const InputType & Parameters () const
 Get the parameters. More...

 
virtual InputType & Parameters ()
 Set the parameters. More...

 
void serialize (Archive &ar, const uint32_t)
 Serialize the layer. More...

 
virtual void SetWeights (typename InputType ::elem_type *)
 Reset the layer parameter. More...

 
virtual bool const & Training () const
 Get whether the layer is currently in training mode. More...

 
virtual bool & Training ()
 Modify whether the layer is currently in training mode. More...

 

Additional Inherited Members

- Protected Attributes inherited from Layer< InputType, OutputType >
std::vector< size_t > inputDimensions
 Logical input dimensions of each point. More...

 
std::vector< size_t > outputDimensions
 Logical output dimensions of each point. More...

 
bool training
 If true, the layer is in training mode; otherwise, it is in testing mode. More...

 
bool validOutputDimensions
 This is true if ComputeOutputDimensions() has been called, and outputDimensions can be considered to be up-to-date. More...

 

Detailed Description


template<typename InputType = arma::mat, typename OutputType = arma::mat, typename RegularizerType = NoRegularizer>

class mlpack::ann::MultiheadAttentionType< InputType, OutputType, RegularizerType >

Multihead Attention allows the model to jointly attend to information from different representation subspaces at different positions.

With a single attention head, averaging inhibits this (see "Attention Is All You Need", arXiv:1706.03762v5).

The MultiheadAttention class takes the concatenated form of the query, key, and value. The query, key, and value are concatenated into a single matrix and fed to the Forward() function as input.

The query, key, and value are matrices of shapes (embedDim * tgtSeqLen, batchSize), (embedDim * srcSeqLen, batchSize), and (embedDim * srcSeqLen, batchSize), respectively. The output is a matrix of shape (embedDim * tgtSeqLen, batchSize). The embeddings are stored consecutively.
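A minimal usage sketch follows, assuming standalone use of the layer outside an FFN network; the header path, the toy sizes, and the manual weight allocation via WeightSize()/SetWeights() are illustrative assumptions, not prescriptions from this reference:

#include <mlpack/methods/ann/layer/multihead_attention.hpp>

using namespace mlpack::ann;

int main()
{
  // Toy dimensions, assumed for illustration only.
  const size_t tgtSeqLen = 10, srcSeqLen = 10, embedDim = 8, numHeads = 2;
  const size_t batchSize = 4;

  // Build the concatenated input: query on top, then key, then value.
  arma::mat query(embedDim * tgtSeqLen, batchSize, arma::fill::randu);
  arma::mat key(embedDim * srcSeqLen, batchSize, arma::fill::randu);
  arma::mat value(embedDim * srcSeqLen, batchSize, arma::fill::randu);
  arma::mat input = arma::join_cols(arma::join_cols(query, key), value);

  MultiheadAttentionType<> attn(tgtSeqLen, srcSeqLen, embedDim, numHeads);

  // Used standalone, the weights have to be allocated and bound by hand.
  arma::mat weights(attn.WeightSize(), 1, arma::fill::randn);
  attn.SetWeights(weights.memptr());

  arma::mat output;
  attn.Forward(input, output);  // Shape (embedDim * tgtSeqLen, batchSize).
}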

Template Parameters
InputType: Type of the input data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
OutputType: Type of the output data (arma::colvec, arma::mat, arma::sp_mat or arma::cube).
RegularizerType: Type of the regularizer to be used.

Definition at line 63 of file multihead_attention.hpp.

Constructor & Destructor Documentation

◆ MultiheadAttentionType() [1/2]

 MultiheadAttentionType ( )

Default constructor.

◆ MultiheadAttentionType() [2/2]

MultiheadAttentionType ( const size_t  tgtSeqLen,
const size_t  srcSeqLen,
const size_t  embedDim,
const size_t  numHeads,
const InputType &  attnmask = InputType(),
const InputType &  keyPaddingMask = InputType() 
)

Create the MultiheadAttention object using the specified modules.

Parameters
tgtSeqLen: Target sequence length.
srcSeqLen: Source sequence length.
embedDim: Total dimension of the model.
numHeads: Number of parallel attention heads.
attnMask: Two dimensional Attention Mask.
keyPaddingMask: Key Padding Mask.
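For example (an illustrative sketch; the sizes are arbitrary and the two mask arguments may simply be left at their empty defaults):

// Self-attention over sequences of length 12 with a 16-dimensional embedding
// split across 4 heads; no attention mask or key padding mask.
MultiheadAttentionType<> attn(12, 12, 16, 4);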

Member Function Documentation

◆ AttentionMask() [1/2]

OutputType const& AttentionMask ( ) const
inline

Get the two dimensional Attention Mask.

Definition at line 163 of file multihead_attention.hpp.

◆ AttentionMask() [2/2]

OutputType& AttentionMask ( )
inline

Modify the two dimensional Attention Mask.

Definition at line 165 of file multihead_attention.hpp.
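As a sketch of how this modifier can be used, the following builds a causal (autoregressive) mask, reusing attn, tgtSeqLen and srcSeqLen from the earlier sketch. Filling disallowed positions with a large negative value so the softmax drives their attention weights toward zero is an assumption based on the standard additive-mask convention, not something stated in this reference:

// Hypothetical causal mask of shape (tgtSeqLen, srcSeqLen): position i may
// attend only to positions j <= i.
arma::mat mask(tgtSeqLen, srcSeqLen, arma::fill::zeros);
for (size_t i = 0; i < tgtSeqLen; ++i)
  for (size_t j = i + 1; j < srcSeqLen; ++j)
    mask(i, j) = -1e9;  // Large negative value, softmaxed to ~0.

attn.AttentionMask() = mask;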

◆ Backward()

void Backward ( const InputType &  ,
const OutputType &  gy,
OutputType &  g 
)

Ordinary feed backward pass of a neural network, calculating the function f(x) by propagating x backwards through f.

Using the results from the feed forward pass.

Parameters
gy: The backpropagated error.
g: The calculated gradient.

Referenced by MultiheadAttentionType< InputType, OutputType, RegularizerType >::Clone().

◆ Clone()

MultiheadAttentionType* Clone ( ) const

Clone the MultiheadAttentionType object.

◆ EmbedDim() [1/2]

size_t EmbedDim ( ) const
inline

Get the embedding dimension.

Definition at line 153 of file multihead_attention.hpp.

◆ EmbedDim() [2/2]

size_t& EmbedDim ( )
inline

Modify the embedding dimension.

Definition at line 155 of file multihead_attention.hpp.

◆ Forward()

void Forward ( const InputType &  input,
OutputType &  output 
)

Ordinary feed forward pass of a neural network, evaluating the function f(x) by propagating the activity forward through f.

Parameters
input: The query matrix.
output: Resulting output activation.

Referenced by MultiheadAttentionType< InputType, OutputType, RegularizerType >::Clone().

◆ Gradient()

void Gradient ( const InputType &  input,
const OutputType &  error,
OutputType &  gradient 
)

Calculate the gradient using the output delta and the input activation.

Parameters
input: The input data used for evaluating specified function.
error: The calculated error.
gradient: The calculated gradient.

Referenced by MultiheadAttentionType< InputType, OutputType, RegularizerType >::Clone().
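A hedged sketch of the backward step, continuing the earlier Forward() sketch (pre-sizing the gradient matrix to WeightSize() is a precaution assumed here; in normal use the enclosing network manages these buffers):

// Forward() must have been called first, since Backward()/Gradient() reuse
// intermediate results from the forward pass.
arma::mat gy(output.n_rows, output.n_cols, arma::fill::randu);  // Incoming error.
arma::mat g;                               // Error w.r.t. the concatenated input.
attn.Backward(input, gy, g);

arma::mat gradient(attn.WeightSize(), 1);  // Gradient w.r.t. the weights.
attn.Gradient(input, gy, gradient);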

◆ InputShape()

size_t InputShape ( ) const
inline

Definition at line 185 of file multihead_attention.hpp.

◆ KeyPaddingMask() [1/2]

OutputType const& KeyPaddingMask ( ) const
inline

Get Key Padding Mask.

Definition at line 168 of file multihead_attention.hpp.

◆ KeyPaddingMask() [2/2]

OutputType& KeyPaddingMask ( )
inline

Modify the Key Padding Mask.

Definition at line 170 of file multihead_attention.hpp.

◆ NumHeads() [1/2]

size_t NumHeads ( ) const
inline

Get the number of attention heads.

Definition at line 158 of file multihead_attention.hpp.

◆ NumHeads() [2/2]

size_t& NumHeads ( )
inline

Modify the number of attention heads.

Definition at line 160 of file multihead_attention.hpp.

◆ OutputDimensions()

const std::vector<size_t> OutputDimensions ( ) const
inline

◆ serialize()

template<typename Archive>
void serialize ( Archive & ar, const uint32_t )

Serialize the layer.

◆ SetWeights()

void SetWeights ( typename OutputType::elem_type *  weightsPtr)

Reset the layer parameters.

◆ SrcSeqLen() [1/2]

size_t SrcSeqLen ( ) const
inline

Get the source sequence length.

Definition at line 148 of file multihead_attention.hpp.

◆ SrcSeqLen() [2/2]

size_t& SrcSeqLen ( )
inline

Modify the source sequence length.

Definition at line 150 of file multihead_attention.hpp.

◆ TgtSeqLen() [1/2]

size_t TgtSeqLen ( ) const
inline

Get the target sequence length.

Definition at line 143 of file multihead_attention.hpp.

◆ TgtSeqLen() [2/2]

size_t& TgtSeqLen ( )
inline

Modify the target sequence length.

Definition at line 145 of file multihead_attention.hpp.

◆ WeightSize() [1/2]

size_t WeightSize ( ) const
inlinevirtual

Get the size of the weights.

Reimplemented from Layer< InputType, OutputType >.

Definition at line 134 of file multihead_attention.hpp.

References MultiheadAttentionType< InputType, OutputType, RegularizerType >::serialize().

◆ WeightSize() [2/2]

const size_t WeightSize ( ) const
inlinevirtual

Get the total number of trainable weights in the layer.

Reimplemented from Layer< InputType, OutputType >.

Definition at line 172 of file multihead_attention.hpp.


The documentation for this class was generated from the following file:

multihead_attention.hpp