mlpack

🔗 SoftmaxRegression

The SoftmaxRegression class implements an L2-regularized multi-class softmax regression classifier for numerical data. This is a multi-class extension of the logistic regression classifier. By default, L-BFGS is used to learn the model. The SoftmaxRegression class offers easy configurability, and arbitrary optimizers can be used to learn the model.

Softmax regression is useful for multi-class classification (e.g. when the classes are 0, 1, and 2). For two-class situations, see also the LogisticRegression class.
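As a rough illustration of what "softmax" means here, the sketch below turns a vector of per-class scores into class probabilities. It assumes the standard softmax formulation and uses only the C++ standard library; SoftmaxProbabilities() is a hypothetical helper written for this page, not an mlpack function, and it is not meant to mirror mlpack's internal implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack): convert per-class scores into
// probabilities with the softmax function.
std::vector<double> SoftmaxProbabilities(const std::vector<double>& scores)
{
  assert(!scores.empty());

  // Subtract the largest score before exponentiating for numerical stability;
  // this does not change the resulting probabilities.
  const double maxScore = *std::max_element(scores.begin(), scores.end());

  std::vector<double> probabilities(scores.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < scores.size(); ++i)
  {
    probabilities[i] = std::exp(scores[i] - maxScore);
    sum += probabilities[i];
  }

  // Normalize so that the probabilities sum to 1.
  for (double& p : probabilities)
    p /= sum;

  return probabilities;
}
```

Equal scores give equal probabilities, and a larger score always maps to a larger probability.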

Simple usage example:

// Train a softmax regression model on random data and predict labels:

// All data and labels are uniform random; 5 dimensional data, 4 classes.
// Replace with a data::Load() call or similar for a real application.
arma::mat dataset(5, 1000, arma::fill::randu); // 1000 points.
arma::Row<size_t> labels =
    arma::randi<arma::Row<size_t>>(1000, arma::distr_param(0, 3));
arma::mat testDataset(5, 500, arma::fill::randu); // 500 test points.

mlpack::SoftmaxRegression sr;          // Step 1: create model.
sr.Train(dataset, labels, 4);          // Step 2: train model.
arma::Row<size_t> predictions;
sr.Classify(testDataset, predictions); // Step 3: classify points.

// Print some information about the test predictions.
std::cout << arma::accu(predictions == 2) << " test points classified as class "
    << "2." << std::endl;

More examples are given in the Simple Examples section below.

See also:

🔗 Constructors




Constructor Parameters:

| name | type | description | default |
|------|------|-------------|---------|
| data | arma::mat | Column-major training matrix. | (N/A) |
| labels | arma::Row<size_t> | Training labels, between 0 and numClasses - 1 (inclusive). Should have length data.n_cols. | (N/A) |
| numClasses | size_t | Number of classes in the dataset. | (N/A) |
| lambda | double | L2 regularization penalty parameter. Must be nonnegative. | 0.0001 |
| fitIntercept | bool | If true, an intercept term is fit to the model. | true |
| optimizer | any ensmallen optimizer | Instantiated ensmallen optimizer for differentiable functions or differentiable separable functions. | (N/A) |
| callbacks... | any set of ensmallen callbacks | Optional callbacks for the ensmallen optimizer, such as ens::ProgressBar(), ens::Report(), or others. | (N/A) |

As an alternative to passing lambda, it can be set with a standalone method: sr.Lambda() = lambda; sets the L2 regularization penalty parameter to lambda.

It is not possible to set fitIntercept except in the constructor or the call to Train().

Note: Setting lambda too small may cause the model to overfit; however, setting it too large may cause the model to underfit. Automatic hyperparameter tuning can be used to find a good value of lambda instead of a manual setting.
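To make the role of lambda concrete, the sketch below computes an L2 penalty of the form 0.5 * lambda * (sum of squared weights). That form is a common convention and is an assumption here; the exact scaling used inside mlpack may differ. L2Penalty() is a hypothetical helper, not an mlpack function.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper (not part of mlpack): L2 penalty for a weight matrix
// stored as rows of doubles, assuming the form 0.5 * lambda * sum(w^2).
double L2Penalty(const std::vector<std::vector<double>>& weights,
                 const double lambda)
{
  assert(lambda >= 0.0); // lambda must be nonnegative.

  double sumOfSquares = 0.0;
  for (const std::vector<double>& row : weights)
    for (const double w : row)
      sumOfSquares += w * w;

  return 0.5 * lambda * sumOfSquares;
}
```

With weights {{1, 2}, {3, 4}} and lambda = 0.1, the penalty is 0.5 * 0.1 * 30 = 1.5. A larger lambda penalizes large weights more heavily, which is why an overly large value can lead to underfitting.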

🔗 Training

If training is not done as part of the constructor call, it can be done with the Train() function:



Types of each argument are the same as in the table for constructors above.

Notes:

🔗 Classification

Once a SoftmaxRegression model is trained, the Classify() member function can be used to make class predictions for new data.





Classification Parameters:

| usage | name | type | description |
|-------|------|------|-------------|
| single-point | point | arma::vec | Single point for classification. |
| single-point | prediction | size_t& | size_t to store class prediction into. |
| single-point | probabilitiesVec | arma::vec& | arma::vec& to store class probabilities into; will have length numClasses. |
| multi-point | data | arma::mat | Set of column-major points for classification. |
| multi-point | predictions | arma::Row<size_t>& | Vector of size_ts to store class prediction into; will be set to length data.n_cols. |
| multi-point | probabilities | arma::mat& | Matrix to store class probabilities into (number of rows will be equal to numClasses; number of columns will be equal to data.n_cols). |
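The relationship between a point's class probabilities and its predicted class can be sketched as an argmax. This assumes, as is standard for softmax classifiers, that the predicted class is the one with the highest probability; MostLikelyClass() is a hypothetical standard-library helper, not an mlpack function.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical helper (not part of mlpack): return the index of the largest
// probability. Ties are resolved in favor of the lowest class index.
std::size_t MostLikelyClass(const std::vector<double>& probabilities)
{
  assert(!probabilities.empty());

  return static_cast<std::size_t>(std::distance(probabilities.begin(),
      std::max_element(probabilities.begin(), probabilities.end())));
}
```

For probabilities {0.1, 0.7, 0.2}, this returns class 1.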

🔗 Other Functionality

For complete functionality, the source code can be consulted. Each method is fully documented.

🔗 Simple Examples

See also the simple usage example for a trivial usage of the SoftmaxRegression class.


Train a softmax regression model using a custom SGD-like optimizer with callbacks.

// See https://datasets.mlpack.org/mnist.train.csv.
arma::mat dataset;
mlpack::data::Load("mnist.train.csv", dataset, true);
// See https://datasets.mlpack.org/mnist.train.labels.csv.
arma::Row<size_t> labels;
mlpack::data::Load("mnist.train.labels.csv", labels, true);

mlpack::SoftmaxRegression sr;

// Create AdaGrad optimizer with custom step size and batch size.
ens::AdaGrad optimizer(0.0005 /* step size */, 16 /* batch size */);
optimizer.MaxIterations() = 10 * dataset.n_cols; // Allow 10 epochs.

// Print a progress bar and an optimization report when training is finished.
sr.Train(dataset, labels, 10 /* numClasses */, optimizer,
    0.01 /* lambda */, true /* fit intercept */, ens::ProgressBar(),
    ens::Report());

// Now predict on test labels and compute accuracy.

// See https://datasets.mlpack.org/mnist.test.csv.
arma::mat testDataset;
mlpack::data::Load("mnist.test.csv", testDataset, true);
// See https://datasets.mlpack.org/mnist.test.labels.csv.
arma::Row<size_t> testLabels;
mlpack::data::Load("mnist.test.labels.csv", testLabels, true);

std::cout << std::endl;
std::cout << "Accuracy on training set: "
    << sr.ComputeAccuracy(dataset, labels) << "%." << std::endl;
std::cout << "Accuracy on test set:     "
    << sr.ComputeAccuracy(testDataset, testLabels) << "%." << std::endl;
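The example above prints the result of ComputeAccuracy() followed by a percent sign. A minimal sketch of an accuracy computation on that scale is shown below, assuming the value is the percentage of points whose predicted label matches the true label. AccuracyPercent() is a hypothetical standard-library helper, not an mlpack function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack): percentage of predictions that
// match the true labels, in the range [0, 100].
double AccuracyPercent(const std::vector<std::size_t>& predictions,
                       const std::vector<std::size_t>& labels)
{
  assert(predictions.size() == labels.size());
  assert(!labels.empty());

  std::size_t correct = 0;
  for (std::size_t i = 0; i < labels.size(); ++i)
    if (predictions[i] == labels[i])
      ++correct;

  return 100.0 * static_cast<double>(correct) /
      static_cast<double>(labels.size());
}
```

With predictions {0, 1, 2, 2} and labels {0, 1, 1, 2}, three of four points match, so the result is 75.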

Train a softmax regression model with AdaGrad and save the model every epoch using a custom ensmallen callback:

// This callback saves the model into "model-<epoch>.bin" after every epoch.
class ModelCheckpoint
{
 public:
  ModelCheckpoint(mlpack::SoftmaxRegression<>& model) : model(model) { }

  template<typename OptimizerType, typename FunctionType, typename MatType>
  bool EndEpoch(OptimizerType& /* optimizer */,
                FunctionType& /* function */,
                const MatType& /* coordinates */,
                const size_t epoch,
                const double /* objective */)
  {
    const std::string filename = "model-" + std::to_string(epoch) + ".bin";
    mlpack::data::Save(filename, "sr_model", model, true);
    return false; // Do not terminate the optimization.
  }

 private:
  mlpack::SoftmaxRegression<>& model;
};

With that callback available, the code to train the model is below:

// See https://datasets.mlpack.org/mnist.train.csv.
arma::mat dataset;
mlpack::data::Load("mnist.train.csv", dataset, true);
// See https://datasets.mlpack.org/mnist.train.labels.csv.
arma::Row<size_t> labels;
mlpack::data::Load("mnist.train.labels.csv", labels, true);

mlpack::SoftmaxRegression sr;

// Create AdaGrad optimizer with small step size and batch size of 32.
ens::AdaGrad adaGrad(0.0005, 32);
adaGrad.MaxIterations() = 10 * dataset.n_cols; // 10 epochs maximum.

// Use the custom callback and an L2 penalty parameter of 0.01.
sr.Train(dataset, labels, 10 /* numClasses */, adaGrad, 0.01, true,
    ModelCheckpoint(sr), ens::ProgressBar());

// Now files like model-1.bin, model-2.bin, etc. should be saved on disk.

Load an existing softmax regression model and print some information about it.

mlpack::SoftmaxRegression sr;
// This assumes that a model called "sr_model" has been saved to the file
// "model-1.bin" (as in the previous example).
mlpack::data::Load("model-1.bin", "sr_model", sr, true);

// Print the dimensionality of the model and some other statistics.
const size_t dimensionality = (sr.FitIntercept()) ?
    (sr.Parameters().n_cols - 1) : (sr.Parameters().n_cols);
std::cout << "The dimensionality of the model in model-1.bin is "
    << dimensionality << "." << std::endl;
std::cout << "The bias parameters for the model, for each class, are: "
    << std::endl;
std::cout << sr.Parameters().col(0).t();

arma::vec point(dimensionality, arma::fill::randu);
std::cout << "The predicted class for a random point is " << sr.Classify(point)
    << "." << std::endl;
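The example above treats each row of Parameters() as the weights for one class, with column 0 holding the intercept when FitIntercept() is true. Under that layout, per-class scores for a point could be computed as in the sketch below. ClassScores() is a hypothetical standard-library helper, not an mlpack function, and the matrix is represented as a vector of rows purely for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of mlpack): compute one linear score per
// class. Each row of `parameters` holds one class's weights; if fitIntercept
// is true, the first entry of each row is the intercept (bias) term.
std::vector<double> ClassScores(
    const std::vector<std::vector<double>>& parameters,
    const std::vector<double>& point,
    const bool fitIntercept)
{
  const std::size_t offset = fitIntercept ? 1 : 0;

  std::vector<double> scores;
  scores.reserve(parameters.size());
  for (const std::vector<double>& row : parameters)
  {
    // Each row must have one weight per dimension, plus the intercept if used.
    assert(row.size() == point.size() + offset);

    double score = fitIntercept ? row[0] : 0.0;
    for (std::size_t d = 0; d < point.size(); ++d)
      score += row[d + offset] * point[d];

    scores.push_back(score);
  }

  return scores;
}
```

This also shows why the dimensionality in the example is the number of columns minus one when an intercept is fit.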

Perform incremental training on multiple datasets with multiple calls to Train().

// Generate two random datasets with four classes.
arma::mat firstDataset(5, 1000, arma::fill::randu); // 1000 points.
arma::Row<size_t> firstLabels =
    arma::randi<arma::Row<size_t>>(1000, arma::distr_param(0, 3));

arma::mat secondDataset(5, 1500, arma::fill::randu); // 1500 points.
arma::Row<size_t> secondLabels =
    arma::randi<arma::Row<size_t>>(1500, arma::distr_param(0, 3));

// Train a model on the first dataset with an L2 regularization penalty
// parameter of 0.01, not fitting an intercept.
mlpack::SoftmaxRegression sr(firstDataset, firstLabels, 4, 0.01, false);

// Now compute the accuracy on the second dataset and print it.
std::cout << "Accuracy on second dataset: "
    << sr.ComputeAccuracy(secondDataset, secondLabels) << "%." << std::endl;

// Train for a second round on the second dataset.
sr.Train(secondDataset, secondLabels, 4);

// Now compute the accuracy on the second dataset again and print it.
// (Note that it may not be all that much better because this is random data!)
std::cout << "Accuracy on second dataset after second training: "
    << sr.ComputeAccuracy(secondDataset, secondLabels) << "%." << std::endl;

🔗 Advanced Functionality: Different Element Types

The SoftmaxRegression class has one template parameter that can be used to control the element type of the model. The full signature of the class is:

SoftmaxRegression<MatType>

MatType specifies the type of matrix used for training data and internal representation of model parameters. Any matrix type that implements the Armadillo API can be used. The example below trains a softmax regression model on sparse 32-bit floating point data.

// Create random, sparse 100-dimensional data, with 3 classes.
arma::sp_fmat dataset;
dataset.sprandu(100, 5000, 0.3);
arma::Row<size_t> labels =
    arma::randi<arma::Row<size_t>>(5000, arma::distr_param(0, 2));

// Train with L2 regularization penalty parameter of 0.1.
mlpack::SoftmaxRegression<arma::sp_fmat> sr(dataset, labels, 3, 0.1);

// Now classify a test point.
arma::sp_fvec point;
point.sprandu(100, 1, 0.3);

size_t prediction;
arma::fvec probabilitiesVec;
sr.Classify(point, prediction, probabilitiesVec);

std::cout << "Prediction for random test point: " << prediction << "."
    << std::endl;
std::cout << "Class probabilities for random test point: "
    << probabilitiesVec.t();

Note: if MatType is a sparse object (e.g. sp_fmat), the internal parameter representation will be a dense matrix containing elements of the same type (e.g. fmat). This is because L2-regularized softmax regression, even when training on sparse data, does not necessarily produce sparse models.