linear_svm.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
13 #define MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 #include <ensmallen.hpp>
17 
18 #include "linear_svm_function.hpp"
19 
20 namespace mlpack {
21 namespace svm {
22 
79 template <typename MatType = arma::mat>
80 class LinearSVM
81 {
82  public:
101  template <typename OptimizerType, typename... CallbackTypes>
102  LinearSVM(const MatType& data,
103  const arma::Row<size_t>& labels,
104  const size_t numClasses,
105  const double lambda,
106  const double delta,
107  const bool fitIntercept,
108  OptimizerType optimizer,
109  CallbackTypes&&... callbacks);
110 
126  template <typename OptimizerType = ens::L_BFGS>
127  LinearSVM(const MatType& data,
128  const arma::Row<size_t>& labels,
129  const size_t numClasses = 2,
130  const double lambda = 0.0001,
131  const double delta = 1.0,
132  const bool fitIntercept = false,
133  OptimizerType optimizer = OptimizerType());
134 
146  LinearSVM(const size_t inputSize,
147  const size_t numClasses = 0,
148  const double lambda = 0.0001,
149  const double delta = 1.0,
150  const bool fitIntercept = false);
161  LinearSVM(const size_t numClasses = 0,
162  const double lambda = 0.0001,
163  const double delta = 1.0,
164  const bool fitIntercept = false);
165 
175  void Classify(const MatType& data,
176  arma::Row<size_t>& labels) const;
177 
189  void Classify(const MatType& data,
190  arma::Row<size_t>& labels,
191  arma::mat& scores) const;
192 
199  void Classify(const MatType& data,
200  arma::mat& scores) const;
201 
210  template<typename VecType>
211  size_t Classify(const VecType& point) const;
212 
222  double ComputeAccuracy(const MatType& testData,
223  const arma::Row<size_t>& testLabels) const;
224 
238  template <typename OptimizerType, typename... CallbackTypes>
239  double Train(const MatType& data,
240  const arma::Row<size_t>& labels,
241  const size_t numClasses,
242  OptimizerType optimizer,
243  CallbackTypes&&... callbacks);
244 
255  template <typename OptimizerType = ens::L_BFGS>
256  double Train(const MatType& data,
257  const arma::Row<size_t>& labels,
258  const size_t numClasses = 2,
259  OptimizerType optimizer = OptimizerType());
260 
261 
263  size_t& NumClasses() { return numClasses; }
265  size_t NumClasses() const { return numClasses; }
266 
268  double& Lambda() { return lambda; }
270  double Lambda() const { return lambda; }
271 
273  double& Delta() { return delta; }
275  double Delta() const { return delta; }
276 
278  bool& FitIntercept() { return fitIntercept; }
279 
281  arma::mat& Parameters() { return parameters; }
283  const arma::mat& Parameters() const { return parameters; }
284 
286  size_t FeatureSize() const
287  { return fitIntercept ? parameters.n_rows - 1 :
288  parameters.n_rows; }
289 
293  template<typename Archive>
294  void serialize(Archive& ar, const uint32_t /* version */)
295  {
296  ar(CEREAL_NVP(parameters));
297  ar(CEREAL_NVP(numClasses));
298  ar(CEREAL_NVP(lambda));
299  ar(CEREAL_NVP(fitIntercept));
300  }
301 
302  private:
304  arma::mat parameters;
306  size_t numClasses;
308  double lambda;
310  double delta;
312  bool fitIntercept;
313 };
314 
315 } // namespace svm
316 } // namespace mlpack
317 
318 // Include implementation.
319 #include "linear_svm_impl.hpp"
320 
321 #endif // MLPACK_METHODS_LINEAR_SVM_LINEAR_SVM_HPP
bool & FitIntercept()
Sets the intercept term flag.
Definition: linear_svm.hpp:278
Linear algebra utility functions, generally performed on matrices or vectors.
double Delta() const
Gets the margin between the correct class and all other classes.
Definition: linear_svm.hpp:275
arma::mat & Parameters()
Set the model parameters.
Definition: linear_svm.hpp:281
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
double & Lambda()
Sets the regularization parameter.
Definition: linear_svm.hpp:268
size_t & NumClasses()
Sets the number of classes.
Definition: linear_svm.hpp:263
double & Delta()
Sets the margin between the correct class and all other classes.
Definition: linear_svm.hpp:273
size_t FeatureSize() const
Gets the feature size (number of features) of the training data.
Definition: linear_svm.hpp:286
LinearSVM(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, const double lambda, const double delta, const bool fitIntercept, OptimizerType optimizer, CallbackTypes &&... callbacks)
Construct the LinearSVM class with the provided data and labels.
const arma::mat & Parameters() const
Get the model parameters.
Definition: linear_svm.hpp:283
void Classify(const MatType &data, arma::Row< size_t > &labels) const
Classify the given points, returning the predicted labels for each point.
constexpr auto data(Container const &container) noexcept -> decltype(container.data())
Definition: iterator.hpp:79
void serialize(Archive &ar, const uint32_t)
Serialize the LinearSVM model.
Definition: linear_svm.hpp:294
The LinearSVM class implements an L2-regularized support vector machine model, and supports training ...
Definition: linear_svm.hpp:80
double Train(const MatType &data, const arma::Row< size_t > &labels, const size_t numClasses, OptimizerType optimizer, CallbackTypes &&... callbacks)
Train the Linear SVM with the given training data.
double ComputeAccuracy(const MatType &testData, const arma::Row< size_t > &testLabels) const
Computes accuracy of the learned model given the feature data and the labels associated with each data point.
double Lambda() const
Gets the regularization parameter.
Definition: linear_svm.hpp:270
size_t NumClasses() const
Gets the number of classes.
Definition: linear_svm.hpp:265