ensmallen
mlpack
fast, flexible C++ machine learning library
test_function_tools.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_TESTS_TEST_FUNCTION_TOOLS_HPP
13 #define MLPACK_TESTS_TEST_FUNCTION_TOOLS_HPP
14 
15 #include <mlpack/core.hpp>
16 
18 
19 using namespace mlpack;
20 using namespace mlpack::distribution;
21 using namespace mlpack::regression;
22 
33 inline void LogisticRegressionTestData(arma::mat& data,
34  arma::mat& testData,
35  arma::mat& shuffledData,
36  arma::Row<size_t>& responses,
37  arma::Row<size_t>& testResponses,
38  arma::Row<size_t>& shuffledResponses)
39 {
40  // Generate a two-Gaussian dataset.
41  GaussianDistribution g1(arma::vec("1.0 1.0 1.0"), arma::eye<arma::mat>(3, 3));
42  GaussianDistribution g2(arma::vec("9.0 9.0 9.0"), arma::eye<arma::mat>(3, 3));
43 
44  data = arma::mat(3, 1000);
45  responses = arma::Row<size_t>(1000);
46  for (size_t i = 0; i < 500; ++i)
47  {
48  data.col(i) = g1.Random();
49  responses[i] = 0;
50  }
51  for (size_t i = 500; i < 1000; ++i)
52  {
53  data.col(i) = g2.Random();
54  responses[i] = 1;
55  }
56 
57  // Shuffle the dataset.
58  arma::uvec indices = arma::shuffle(arma::linspace<arma::uvec>(0,
59  data.n_cols - 1, data.n_cols));
60  shuffledData = arma::mat(3, 1000);
61  shuffledResponses = arma::Row<size_t>(1000);
62  for (size_t i = 0; i < data.n_cols; ++i)
63  {
64  shuffledData.col(i) = data.col(indices[i]);
65  shuffledResponses[i] = responses[indices[i]];
66  }
67 
68  // Create a test set.
69  testData = arma::mat(3, 1000);
70  testResponses = arma::Row<size_t>(1000);
71  for (size_t i = 0; i < 500; ++i)
72  {
73  testData.col(i) = g1.Random();
74  testResponses[i] = 0;
75  }
76  for (size_t i = 500; i < 1000; ++i)
77  {
78  testData.col(i) = g2.Random();
79  testResponses[i] = 1;
80  }
81 }
82 
83 #endif
A single multivariate Gaussian distribution.
.hpp
Definition: add_to_po.hpp:21
Probability distributions.
void LogisticRegressionTestData(arma::mat &data, arma::mat &testData, arma::mat &shuffledData, arma::Row< size_t > &responses, arma::Row< size_t > &testResponses, arma::Row< size_t > &shuffledResponses)
Create the data for a logistic regression test.
arma::vec Random() const
Return a randomly generated observation according to the probability distribution defined by this object.
Include all of the base components required to write mlpack methods, and the main mlpack Doxygen documentation.
Regression methods.
Definition: lars.hpp:30