shuffle_data.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_MATH_SHUFFLE_DATA_HPP
13 #define MLPACK_CORE_MATH_SHUFFLE_DATA_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace math {
19 
27 template<typename MatType, typename LabelsType>
28 void ShuffleData(const MatType& inputPoints,
29  const LabelsType& inputLabels,
30  MatType& outputPoints,
31  LabelsType& outputLabels,
32  const std::enable_if_t<!arma::is_SpMat<MatType>::value>* = 0,
33  const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
34 {
35  // Generate ordering.
36  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
37  inputPoints.n_cols - 1, inputPoints.n_cols));
38 
39  outputPoints = inputPoints.cols(ordering);
40  outputLabels = inputLabels.cols(ordering);
41 }
42 
50 template<typename MatType, typename LabelsType>
51 void ShuffleData(const MatType& inputPoints,
52  const LabelsType& inputLabels,
53  MatType& outputPoints,
54  LabelsType& outputLabels,
55  const std::enable_if_t<arma::is_SpMat<MatType>::value>* = 0,
56  const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
57 {
58  // Generate ordering.
59  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
60  inputPoints.n_cols - 1, inputPoints.n_cols));
61 
62  // Extract coordinate list representation.
63  arma::umat locations(2, inputPoints.n_nonzero);
64  arma::Col<typename MatType::elem_type> values(inputPoints.n_nonzero);
65  typename MatType::const_iterator it = inputPoints.begin();
66  size_t index = 0;
67  while (it != inputPoints.end())
68  {
69  locations(0, index) = it.row();
70  locations(1, index) = ordering[it.col()];
71  values(index) = (*it);
72  ++it;
73  ++index;
74  }
75 
76  if (&inputPoints == &outputPoints || &inputLabels == &outputLabels)
77  {
78  MatType newOutputPoints(locations, values, inputPoints.n_rows,
79  inputPoints.n_cols, true);
80  LabelsType newOutputLabels(inputLabels.n_elem);
81  newOutputLabels.cols(ordering) = inputLabels;
82 
83  outputPoints = std::move(newOutputPoints);
84  outputLabels = std::move(newOutputLabels);
85  }
86  else
87  {
88  outputPoints = MatType(locations, values, inputPoints.n_rows,
89  inputPoints.n_cols, true);
90  outputLabels.set_size(inputLabels.n_elem);
91  outputLabels.cols(ordering) = inputLabels;
92  }
93 }
94 
102 template<typename MatType, typename LabelsType>
103 void ShuffleData(const MatType& inputPoints,
104  const LabelsType& inputLabels,
105  MatType& outputPoints,
106  LabelsType& outputLabels,
107  const std::enable_if_t<!arma::is_SpMat<MatType>::value>* = 0,
108  const std::enable_if_t<arma::is_Cube<MatType>::value>* = 0,
109  const std::enable_if_t<arma::is_Cube<LabelsType>::value>* = 0)
110 {
111  // Generate ordering.
112  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
113  inputPoints.n_cols - 1, inputPoints.n_cols));
114 
115  // Properly handle the case where the input and output data are the same
116  // object.
117  MatType* outputPointsPtr = &outputPoints;
118  LabelsType* outputLabelsPtr = &outputLabels;
119  if (&inputPoints == &outputPoints)
120  outputPointsPtr = new MatType();
121  if (&inputLabels == &outputLabels)
122  outputLabelsPtr = new LabelsType();
123 
124  outputPointsPtr->set_size(inputPoints.n_rows, inputPoints.n_cols,
125  inputPoints.n_slices);
126  outputLabelsPtr->set_size(inputLabels.n_rows, inputLabels.n_cols,
127  inputLabels.n_slices);
128  for (size_t i = 0; i < ordering.n_elem; ++i)
129  {
130  outputPointsPtr->tube(0, ordering[i], outputPointsPtr->n_rows - 1,
131  ordering[i]) = inputPoints.tube(0, i, inputPoints.n_rows - 1, i);
132  outputLabelsPtr->tube(0, ordering[i], outputLabelsPtr->n_rows - 1,
133  ordering[i]) = inputLabels.tube(0, i, inputLabels.n_rows - 1, i);
134  }
135 
136  // Clean up memory if needed.
137  if (&inputPoints == &outputPoints)
138  {
139  outputPoints = std::move(*outputPointsPtr);
140  delete outputPointsPtr;
141  }
142 
143  if (&inputLabels == &outputLabels)
144  {
145  outputLabels = std::move(*outputLabelsPtr);
146  delete outputLabelsPtr;
147  }
148 }
149 
159 template<typename MatType, typename LabelsType, typename WeightsType>
160 void ShuffleData(const MatType& inputPoints,
161  const LabelsType& inputLabels,
162  const WeightsType& inputWeights,
163  MatType& outputPoints,
164  LabelsType& outputLabels,
165  WeightsType& outputWeights,
166  const std::enable_if_t<!arma::is_SpMat<MatType>::value>* = 0,
167  const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
168 {
169  // Generate ordering.
170  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
171  inputPoints.n_cols - 1, inputPoints.n_cols));
172 
173  outputPoints = inputPoints.cols(ordering);
174  outputLabels = inputLabels.cols(ordering);
175  outputWeights = inputWeights.cols(ordering);
176 }
177 
187 template<typename MatType, typename LabelsType, typename WeightsType>
188 void ShuffleData(const MatType& inputPoints,
189  const LabelsType& inputLabels,
190  const WeightsType& inputWeights,
191  MatType& outputPoints,
192  LabelsType& outputLabels,
193  WeightsType& outputWeights,
194  const std::enable_if_t<arma::is_SpMat<MatType>::value>* = 0,
195  const std::enable_if_t<!arma::is_Cube<MatType>::value>* = 0)
196 {
197  // Generate ordering.
198  arma::uvec ordering = arma::shuffle(arma::linspace<arma::uvec>(0,
199  inputPoints.n_cols - 1, inputPoints.n_cols));
200 
201  // Extract coordinate list representation.
202  arma::umat locations(2, inputPoints.n_nonzero);
203  arma::Col<typename MatType::elem_type> values(inputPoints.n_nonzero);
204  typename MatType::const_iterator it = inputPoints.begin();
205  size_t index = 0;
206  while (it != inputPoints.end())
207  {
208  locations(0, index) = it.row();
209  locations(1, index) = ordering[it.col()];
210  values(index) = (*it);
211  ++it;
212  ++index;
213  }
214 
215  if (&inputPoints == &outputPoints || &inputLabels == &outputLabels ||
216  &inputWeights == &outputWeights)
217  {
218  MatType newOutputPoints(locations, values, inputPoints.n_rows,
219  inputPoints.n_cols, true);
220  LabelsType newOutputLabels(inputLabels.n_elem);
221  WeightsType newOutputWeights(inputWeights.n_elem);
222  newOutputLabels.cols(ordering) = inputLabels;
223  newOutputWeights.cols(ordering) = inputWeights;
224 
225  outputPoints = std::move(newOutputPoints);
226  outputLabels = std::move(newOutputLabels);
227  outputWeights = std::move(newOutputWeights);
228  }
229  else
230  {
231  outputPoints = MatType(locations, values, inputPoints.n_rows,
232  inputPoints.n_cols, true);
233  outputLabels.set_size(inputLabels.n_elem);
234  outputLabels.cols(ordering) = inputLabels;
235  outputWeights.set_size(inputWeights.n_elem);
236  outputWeights.cols(ordering) = inputWeights;
237  }
238 }
239 
240 } // namespace math
241 } // namespace mlpack
242 
243 #endif
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
void ShuffleData(const MatType &inputPoints, const LabelsType &inputLabels, MatType &outputPoints, LabelsType &outputLabels, const std::enable_if_t<!arma::is_SpMat< MatType >::value > *=0, const std::enable_if_t<!arma::is_Cube< MatType >::value > *=0)
Shuffle a dataset and associated labels (or responses).