mlpack

🔗 LARS

The LARS class implements the least-angle regression (LARS) algorithm for L1-penalized and L2-penalized linear regression. LARS can also solve the LASSO (least absolute shrinkage and selection operator) problem. The LARS algorithm is a path algorithm, and thus will recover solutions for all L1 penalty parameters greater than or equal to the given L1 penalty parameter.
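For orientation, the problem described above can be written as a penalized least-squares objective. The form below is the standard elastic net formulation and is given as a sketch: the exact scaling (in particular the factors of 1/2) is an assumption, X denotes the data with one point per row, y the responses, and β the model weights, with λ1 and λ2 playing the roles of lambda1 and lambda2.

```latex
\min_{\beta} \; \frac{1}{2} \lVert X \beta - y \rVert_2^2
  + \lambda_1 \lVert \beta \rVert_1
  + \frac{\lambda_2}{2} \lVert \beta \rVert_2^2
```

With λ2 = 0 this is the LASSO; with λ1 = 0 and λ2 > 0 it reduces to ridge regression; with both set to 0 it is ordinary least squares.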

Simple usage example:

// Train a LARS model on random numeric data and make predictions.

// The data is uniform random and the responses are standard normal; this uses
// 10-dimensional data.
// Replace with a data::Load() call or similar for a real application.
arma::mat dataset(10, 1000, arma::fill::randu); // 1000 points.
arma::rowvec responses = arma::randn<arma::rowvec>(1000);
arma::mat testDataset(10, 500, arma::fill::randu); // 500 test points.

mlpack::LARS lars(true, 0.1 /* L1 penalty */); // Step 1: create model.
lars.Train(dataset, responses);                // Step 2: train model.
arma::rowvec predictions;
lars.Predict(testDataset, predictions);        // Step 3: use model to predict.

// Print some information about the test predictions.
std::cout << arma::accu(predictions > 0.7) << " test points predicted to have"
    << " responses greater than 0.7." << std::endl;
std::cout << arma::accu(predictions < 0) << " test points predicted to have "
    << "negative responses." << std::endl;

More examples...

See also:

🔗 Constructors




Constructor Parameters:

| name | type | description | default |
|------|------|-------------|---------|
| data | arma::mat | Training matrix. | (N/A) |
| responses | arma::rowvec | Training responses (e.g. values to predict). Should have length data.n_cols. | (N/A) |
| colMajor | bool | Should be set to true if data is column-major. Passing row-major data can avoid a transpose operation. | true |
| useCholesky | bool | If true, use the Cholesky decomposition of the Gram matrix to solve linear systems (as opposed to the full Gram matrix). | false |
| gramMatrix | arma::mat | Precomputed Gram matrix of data (i.e. data * data.t() for column-major data). | (N/A) |
| lambda1 | double | L1 regularization penalty parameter. | 0.0 |
| lambda2 | double | L2 regularization penalty parameter. | 0.0 |
| tolerance | double | Tolerance on feature correlations for convergence. | 1e-16 |
| fitIntercept | bool | If true, an intercept term will be included in the model. | true |
| normalizeData | bool | If true, data will be normalized before fitting the model. | true |

As an alternative to passing hyperparameters to the constructor, each hyperparameter can be set with a standalone method before calling Train(). For example, lars.Lambda1() = 0.1; sets the L1 penalty and lars.Lambda2() = 0.01; sets the L2 penalty.

Notes:

🔗 Training

If training is not done as part of the constructor call, it can be done with the Train() function:



Types of each argument are the same as in the table for constructors above.

Notes:

🔗 Prediction

Once a LARS model is trained, the Predict() member function can be used to make predictions for new data.



Prediction Parameters:

| usage | name | type | description |
|-------|------|------|-------------|
| single-point | point | arma::vec | Single point for prediction. |
| multi-point | data | arma::mat | Set of column-major points for prediction. |
| multi-point | predictions | arma::rowvec& | Vector of doubles to store predictions into. Will be set to length data.n_cols. |
| multi-point | colMajor | bool | Should be set to true if data is column-major. Passing row-major data can avoid a transpose operation. (Default true.) |
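To make the single-point case concrete, the sketch below shows what a prediction from a linear model amounts to: the dot product of the weights with the point, plus the intercept. It uses only the standard library; PredictPoint is a hypothetical helper written for illustration, not part of mlpack, and it assumes the weights are already expressed in the same scale as the input point.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Prediction of a linear model for one point: dot(beta, point) + intercept.
// (PredictPoint is a hypothetical helper for illustration, not mlpack API.)
double PredictPoint(const std::vector<double>& beta,
                    const double intercept,
                    const std::vector<double>& point)
{
  // The point must have one value per model dimension.
  assert(beta.size() == point.size());

  double prediction = intercept;
  for (std::size_t j = 0; j < beta.size(); ++j)
    prediction += beta[j] * point[j];
  return prediction;
}
```

For a model without an intercept, passing 0.0 as the intercept gives the plain dot product.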

🔗 Other Functionality

🔗 The LARS Path

LARS is a path (or stepwise) algorithm, meaning it adds one feature at a time to the model. This in turn means that when we train a LARS model with lambda1 set to l, we also recover every possible LARS model on the same data with a lambda1 greater than l.
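One way to picture a path is as a list of (lambda1, weights) pairs ordered by decreasing lambda1. Because the lasso path is piecewise linear in the penalty, weights for a penalty that falls between two stored points can be sketched by linear interpolation. The code below is a standard-library illustration under that assumption; PathPoint and InterpolatePath are hypothetical names, and this is not a description of how mlpack selects a model internally.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One point on a regularization path: an L1 penalty and its weights.
// (PathPoint and InterpolatePath are hypothetical names, not mlpack API.)
struct PathPoint
{
  double lambda1;
  std::vector<double> beta;
};

// Return weights for `lambda1` by interpolating linearly between the two
// neighboring path points.  `path` must be sorted by decreasing lambda1, and
// `lambda1` must lie within the range covered by the path.
std::vector<double> InterpolatePath(const std::vector<PathPoint>& path,
                                    const double lambda1)
{
  assert(!path.empty());
  assert(lambda1 <= path.front().lambda1 && lambda1 >= path.back().lambda1);

  for (std::size_t i = 0; i + 1 < path.size(); ++i)
  {
    const PathPoint& hi = path[i];
    const PathPoint& lo = path[i + 1];
    if (lambda1 <= hi.lambda1 && lambda1 >= lo.lambda1)
    {
      // t = 0 gives the weights at hi.lambda1; t = 1 gives those at lo.lambda1.
      const double t = (hi.lambda1 == lo.lambda1) ? 0.0 :
          (hi.lambda1 - lambda1) / (hi.lambda1 - lo.lambda1);

      std::vector<double> beta(hi.beta.size());
      for (std::size_t j = 0; j < beta.size(); ++j)
        beta[j] = (1.0 - t) * hi.beta[j] + t * lo.beta[j];
      return beta;
    }
  }

  // Only reached when the path holds a single point.
  return path.back().beta;
}
```

In this toy representation, a feature that enters the model at some path point has weight zero at every earlier (larger lambda1) point, which is why larger penalties give sparser models.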

The LARS class provides a way to access all of the models on the path, and switch between them for prediction purposes:

🔗 Simple Examples

See also the simple usage example for a trivial usage of the LARS class.


Train a LARS model in the constructor, and print the MSE on training and test data for each set of weights in the path.

// See https://datasets.mlpack.org/wave_energy_farm_100.csv.
arma::mat data;
mlpack::data::Load("wave_energy_farm_100.csv", data, true);

// Split the last row off: it is the responses.  Also, normalize the responses
// to [0, 1].
arma::rowvec responses = data.row(data.n_rows - 1);
responses /= responses.max();
data.shed_row(data.n_rows - 1);

// Split into a training and test dataset.  20% of the data is held out as a
// test set.
arma::mat trainingData, testData;
arma::rowvec trainingResponses, testResponses;
mlpack::data::Split(data, responses, trainingData, testData, trainingResponses,
    testResponses, 0.2);

// Train a LARS model with lambda1 = 1e-5 and lambda2 = 1e-6.
mlpack::LARS lars(trainingData, trainingResponses, true, true, 1e-5, 1e-6);

// Iterate over all the models in the path.
const size_t pathLength = lars.BetaPath().size();
for (size_t i = 0; i < pathLength; ++i)
{
  // Use the i'th model in the path.
  lars.SelectBeta(lars.LambdaPath()[i]);

  // ComputeError() returns the total loss, which we need to divide by the
  // number of points to get the MSE.
  const double trainMSE = lars.ComputeError(trainingData, trainingResponses) /
      trainingData.n_cols;
  const double testMSE = lars.ComputeError(testData, testResponses) /
      testData.n_cols;
  std::cout << "L1 penalty parameter: " << lars.SelectedLambda1() << std::endl;
  std::cout << "  MSE on training set: " << trainMSE << "." << std::endl;
  std::cout << "  MSE on test set:     " << testMSE << "." << std::endl;
}
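The division by the number of points in the loop above is the usual conversion from a total loss to a mean. The standard-library sketch below spells this out, assuming the total loss is the sum of squared residuals; TotalSquaredError and MeanSquaredError are hypothetical helpers for illustration and are not mlpack functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Sum of squared residuals between predictions and true responses.
// (Hypothetical helper for illustration, not part of mlpack.)
double TotalSquaredError(const std::vector<double>& predictions,
                         const std::vector<double>& responses)
{
  assert(predictions.size() == responses.size());

  double total = 0.0;
  for (std::size_t i = 0; i < predictions.size(); ++i)
  {
    const double residual = predictions[i] - responses[i];
    total += residual * residual;
  }
  return total;
}

// Mean squared error: the total squared error divided by the point count.
double MeanSquaredError(const std::vector<double>& predictions,
                        const std::vector<double>& responses)
{
  assert(!responses.empty());
  return TotalSquaredError(predictions, responses) / responses.size();
}
```

Reporting the mean rather than the total makes the training and test errors comparable even though the two sets contain different numbers of points.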

Train a LARS model, print predictions for a random point, and save to a file.

// See https://datasets.mlpack.org/admission_predict.csv.
arma::mat data;
mlpack::data::Load("admission_predict.csv", data, true); 

// See https://datasets.mlpack.org/admission_predict.responses.csv.
arma::rowvec responses;
mlpack::data::Load("admission_predict.responses.csv", responses, true);

// Train a LARS model with only L2 regularization.
mlpack::LARS lars(data, responses, true, true, 0.0, 0.1 /* lambda2 */);

// Predict on a random point.
arma::vec point = arma::randu<arma::vec>(data.n_rows);
const double prediction = lars.Predict(point);

std::cout << "Prediction on random point: " << prediction << "." << std::endl;

// Save the model to "lars_model.bin" with the name "lars".
mlpack::data::Save("lars_model.bin", "lars", lars, true);

Load a LARS model from disk and print some information about it.

// This assumes a model named "lars" has previously been saved to
// "lars_model.bin".
mlpack::LARS lars;
mlpack::data::Load("lars_model.bin", "lars", lars, true);

if (lars.BetaPath().size() == 0)
{
  std::cout << "lars_model.bin contains an untrained LARS model." << std::endl;
}
else
{
  std::cout << "Information on the LARS model in lars_model.bin:" << std::endl;

  std::cout << " - Model dimensionality: " << lars.Beta().n_elem << "."
      << std::endl;
  std::cout << " - Has intercept: "
      << (lars.FitIntercept() ? std::string("yes") : std::string("no")) << "."
      << std::endl;
  std::cout << " - Current L1 regularization penalty parameter value: "
      << lars.SelectedLambda1() << "." << std::endl;
  std::cout << " - L2 regularization penalty parameter: " << lars.Lambda2()
      << "." << std::endl;
  std::cout << " - Number of nonzero elements in model: "
      << lars.ActiveSet().size() << "." << std::endl;
  std::cout << " - Number of models in LARS path: " << lars.BetaPath().size()
      << "." << std::endl;
  std::cout << " - Model weight for dimension 0: " << lars.Beta()[0] << "."
      << std::endl;

  if (lars.FitIntercept())
  {
    std::cout << " - Intercept value: " << lars.Intercept() << "." << std::endl;
  }
}

Train several models with different L2 regularization penalty parameters, using a precomputed Gram matrix.

// See https://datasets.mlpack.org/admission_predict.csv.
arma::mat data;
mlpack::data::Load("admission_predict.csv", data, true);

// See https://datasets.mlpack.org/admission_predict.responses.csv.
arma::rowvec responses;
mlpack::data::Load("admission_predict.responses.csv", responses, true);

// Precompute Gram matrix.
arma::mat gramMatrix = data * data.t();

std::vector<double> lambda2Values = { 0.01, 0.1, 1.0, 10.0, 100.0 };
for (double lambda2 : lambda2Values)
{
  // Build a LARS model using the precomputed Gram matrix.  We did not normalize
  // or center the data before computing the Gram matrix, so we have to set
  // fitIntercept and normalizeData accordingly.
  mlpack::LARS lars(data, responses, true, true, gramMatrix, 0.01, lambda2,
      1e-16, false, false);

  std::cout << "MSE with L2 penalty " << lambda2 << ": "
      << (lars.ComputeError(data, responses) / data.n_cols) << "." << std::endl;
}
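The Gram matrix used above, data * data.t(), holds inner products between dimensions rather than between points, so for d-dimensional column-major data it is d x d regardless of the number of points. The sketch below computes it with the standard library only, to show what each entry contains; the Matrix alias and GramMatrix function are hypothetical names, and in practice the one-line Armadillo expression from the example is the natural choice.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// data[i][k] is the value of dimension i for point k (column-major points).
// (Matrix and GramMatrix are hypothetical names for illustration.)
using Matrix = std::vector<std::vector<double>>;

// Compute data * data^T: entry (i, j) is the inner product of dimension i and
// dimension j taken over all points.
Matrix GramMatrix(const Matrix& data)
{
  const std::size_t dims = data.size();
  Matrix gram(dims, std::vector<double>(dims, 0.0));

  for (std::size_t i = 0; i < dims; ++i)
  {
    // Every dimension must hold a value for every point.
    assert(data[i].size() == data[0].size());

    // The result is symmetric, so compute the lower triangle and mirror it.
    for (std::size_t j = 0; j <= i; ++j)
    {
      double sum = 0.0;
      for (std::size_t k = 0; k < data[i].size(); ++k)
        sum += data[i][k] * data[j][k];
      gram[i][j] = sum;
      gram[j][i] = sum;
    }
  }
  return gram;
}
```

Since the cost of forming this matrix grows with the number of points, computing it once and reusing it across several models, as in the loop above, avoids repeating that work.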

🔗 Advanced Functionality: Different Element Types

The LARS class has one template parameter that can be used to control the element type of the model. The full signature of the class is:

LARS<ModelMatType>

ModelMatType specifies the type of matrix used for the internal representation of model parameters. Any matrix type that implements the Armadillo API can be used.

Note that the Train() and Predict() functions are themselves templated and accept any matrix type that has the same element type as ModelMatType. So, for instance, a LARS<arma::sp_mat> can accept an arma::mat for training.

The example below trains a LARS model on 32-bit precision data, using arma::sp_fmat to store the model parameters.

// Create random (dense) 1000-dimensional data with 32-bit floats.
arma::fmat dataset(1000, 5000, arma::fill::randu);

// Generate noisy responses from random data.
arma::fvec trueWeights(1000, arma::fill::randu);
arma::frowvec responses = trueWeights.t() * dataset +
    0.01 * arma::randu<arma::frowvec>(5000) /* noise term */;

mlpack::LARS<arma::sp_fmat> lars;
lars.Lambda1() = 0.1;
lars.Lambda2() = 0.01;

lars.Train(dataset, responses);

// Compute the MSE on the training set and a random test set.
arma::fmat testDataset(1000, 2500, arma::fill::randu);
arma::frowvec testResponses = trueWeights.t() * testDataset +
    0.01 * arma::randu<arma::frowvec>(2500) /* noise term */;

const float trainMSE = lars.ComputeError(dataset, responses) / dataset.n_cols;
const float testMSE = lars.ComputeError(testDataset, testResponses) /
    testDataset.n_cols;

std::cout << "MSE on training set: " << trainMSE << "." << std::endl;
std::cout << "MSE on test set:     " << testMSE << "." << std::endl;

Note: it is generally only more efficient to use a sparse type (e.g. arma::sp_mat) for ModelMatType when the L1 regularization parameter is set such that a highly sparse model is produced.
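To see why the L1 penalty controls sparsity, consider the simplified case of an orthonormal design, where the lasso solution is known to be the soft-thresholding of the unpenalized least-squares weights. The sketch below is not LARS and not mlpack code; SoftThreshold and CountNonzeros are hypothetical helpers that only illustrate how the number of nonzero weights shrinks as lambda1 grows.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Soft-thresholding operator: shrink z toward zero by lambda1, and set it to
// exactly zero when |z| <= lambda1.
// (Hypothetical helper for illustration, not part of mlpack.)
double SoftThreshold(const double z, const double lambda1)
{
  if (z > lambda1)
    return z - lambda1;
  if (z < -lambda1)
    return z + lambda1;
  return 0.0;
}

// Count how many weights remain nonzero after soft-thresholding by lambda1.
std::size_t CountNonzeros(const std::vector<double>& weights,
                          const double lambda1)
{
  std::size_t count = 0;
  for (const double w : weights)
  {
    if (SoftThreshold(w, lambda1) != 0.0)
      ++count;
  }
  return count;
}
```

When only a small fraction of the weights survives, a sparse ModelMatType stores far fewer values than a dense one; when most weights are nonzero, the bookkeeping overhead of a sparse type tends to outweigh the savings.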