decision_tree.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
14 #define MLPACK_METHODS_DECISION_TREE_DECISION_TREE_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "gini_gain.hpp"
20 #include "all_dimension_select.hpp"
21 #include <type_traits>
22 
23 namespace mlpack {
24 namespace tree {
25 
33 template<typename FitnessFunction = GiniGain,
34  template<typename> class NumericSplitType = BestBinaryNumericSplit,
35  template<typename> class CategoricalSplitType = AllCategoricalSplit,
36  typename DimensionSelectionType = AllDimensionSelect,
37  typename ElemType = double,
38  bool NoRecursion = false>
39 class DecisionTree :
40  public NumericSplitType<FitnessFunction>::template
41  AuxiliarySplitInfo<ElemType>,
42  public CategoricalSplitType<FitnessFunction>::template
43  AuxiliarySplitInfo<ElemType>
44 {
45  public:
47  typedef NumericSplitType<FitnessFunction> NumericSplit;
49  typedef CategoricalSplitType<FitnessFunction> CategoricalSplit;
51  typedef DimensionSelectionType DimensionSelection;
52 
70  template<typename MatType, typename LabelsType>
71  DecisionTree(MatType data,
72  const data::DatasetInfo& datasetInfo,
73  LabelsType labels,
74  const size_t numClasses,
75  const size_t minimumLeafSize = 10,
76  const double minimumGainSplit = 1e-7,
77  const size_t maximumDepth = 0,
78  DimensionSelectionType dimensionSelector =
79  DimensionSelectionType());
80 
97  template<typename MatType, typename LabelsType>
98  DecisionTree(MatType data,
99  LabelsType labels,
100  const size_t numClasses,
101  const size_t minimumLeafSize = 10,
102  const double minimumGainSplit = 1e-7,
103  const size_t maximumDepth = 0,
104  DimensionSelectionType dimensionSelector =
105  DimensionSelectionType());
106 
126  template<typename MatType, typename LabelsType, typename WeightsType>
127  DecisionTree(MatType data,
128  const data::DatasetInfo& datasetInfo,
129  LabelsType labels,
130  const size_t numClasses,
131  WeightsType weights,
132  const size_t minimumLeafSize = 10,
133  const double minimumGainSplit = 1e-7,
134  const size_t maximumDepth = 0,
135  DimensionSelectionType dimensionSelector =
136  DimensionSelectionType(),
137  const std::enable_if_t<arma::is_arma_type<
138  typename std::remove_reference<WeightsType>::type>::value>*
139  = 0);
140 
159  template<typename MatType, typename LabelsType, typename WeightsType>
160  DecisionTree(MatType data,
161  LabelsType labels,
162  const size_t numClasses,
163  WeightsType weights,
164  const size_t minimumLeafSize = 10,
165  const double minimumGainSplit = 1e-7,
166  const size_t maximumDepth = 0,
167  DimensionSelectionType dimensionSelector =
168  DimensionSelectionType(),
169  const std::enable_if_t<arma::is_arma_type<
170  typename std::remove_reference<WeightsType>::type>::value>*
171  = 0);
172 
173 
180  DecisionTree(const size_t numClasses = 1);
181 
188  DecisionTree(const DecisionTree& other);
189 
195  DecisionTree(DecisionTree&& other);
196 
203  DecisionTree& operator=(const DecisionTree& other);
204 
211 
215  ~DecisionTree();
216 
237  template<typename MatType, typename LabelsType>
238  double Train(MatType data,
239  const data::DatasetInfo& datasetInfo,
240  LabelsType labels,
241  const size_t numClasses,
242  const size_t minimumLeafSize = 10,
243  const double minimumGainSplit = 1e-7,
244  const size_t maximumDepth = 0,
245  DimensionSelectionType dimensionSelector =
246  DimensionSelectionType());
247 
266  template<typename MatType, typename LabelsType>
267  double Train(MatType data,
268  LabelsType labels,
269  const size_t numClasses,
270  const size_t minimumLeafSize = 10,
271  const double minimumGainSplit = 1e-7,
272  const size_t maximumDepth = 0,
273  DimensionSelectionType dimensionSelector =
274  DimensionSelectionType());
275 
297  template<typename MatType, typename LabelsType, typename WeightsType>
298  double Train(MatType data,
299  const data::DatasetInfo& datasetInfo,
300  LabelsType labels,
301  const size_t numClasses,
302  WeightsType weights,
303  const size_t minimumLeafSize = 10,
304  const double minimumGainSplit = 1e-7,
305  const size_t maximumDepth = 0,
306  DimensionSelectionType dimensionSelector =
307  DimensionSelectionType(),
308  const std::enable_if_t<arma::is_arma_type<typename
309  std::remove_reference<WeightsType>::type>::value>* = 0);
310 
330  template<typename MatType, typename LabelsType, typename WeightsType>
331  double Train(MatType data,
332  LabelsType labels,
333  const size_t numClasses,
334  WeightsType weights,
335  const size_t minimumLeafSize = 10,
336  const double minimumGainSplit = 1e-7,
337  const size_t maximumDepth = 0,
338  DimensionSelectionType dimensionSelector =
339  DimensionSelectionType(),
340  const std::enable_if_t<arma::is_arma_type<typename
341  std::remove_reference<WeightsType>::type>::value>* = 0);
342 
349  template<typename VecType>
350  size_t Classify(const VecType& point) const;
351 
361  template<typename VecType>
362  void Classify(const VecType& point,
363  size_t& prediction,
364  arma::vec& probabilities) const;
365 
373  template<typename MatType>
374  void Classify(const MatType& data,
375  arma::Row<size_t>& predictions) const;
376 
387  template<typename MatType>
388  void Classify(const MatType& data,
389  arma::Row<size_t>& predictions,
390  arma::mat& probabilities) const;
391 
395  template<typename Archive>
396  void serialize(Archive& ar, const unsigned int /* version */);
397 
399  size_t NumChildren() const { return children.size(); }
400 
402  const DecisionTree& Child(const size_t i) const { return *children[i]; }
404  DecisionTree& Child(const size_t i) { return *children[i]; }
405 
408  size_t SplitDimension() const { return splitDimension; }
409 
417  template<typename VecType>
418  size_t CalculateDirection(const VecType& point) const;
419 
423  size_t NumClasses() const;
424 
425  private:
427  std::vector<DecisionTree*> children;
429  size_t splitDimension;
432  size_t dimensionTypeOrMajorityClass;
440  arma::vec classProbabilities;
441 
445  typedef typename NumericSplit::template AuxiliarySplitInfo<ElemType>
446  NumericAuxiliarySplitInfo;
447  typedef typename CategoricalSplit::template AuxiliarySplitInfo<ElemType>
448  CategoricalAuxiliarySplitInfo;
449 
453  template<bool UseWeights, typename RowType, typename WeightsRowType>
454  void CalculateClassProbabilities(const RowType& labels,
455  const size_t numClasses,
456  const WeightsRowType& weights);
457 
475  template<bool UseWeights, typename MatType>
476  double Train(MatType& data,
477  const size_t begin,
478  const size_t count,
479  const data::DatasetInfo& datasetInfo,
480  arma::Row<size_t>& labels,
481  const size_t numClasses,
482  arma::rowvec& weights,
483  const size_t minimumLeafSize,
484  const double minimumGainSplit,
485  const size_t maximumDepth,
486  DimensionSelectionType& dimensionSelector);
487 
504  template<bool UseWeights, typename MatType>
505  double Train(MatType& data,
506  const size_t begin,
507  const size_t count,
508  arma::Row<size_t>& labels,
509  const size_t numClasses,
510  arma::rowvec& weights,
511  const size_t minimumLeafSize,
512  const double minimumGainSplit,
513  const size_t maximumDepth,
514  DimensionSelectionType& dimensionSelector);
515 };
516 
520 template<typename FitnessFunction = GiniGain,
521  template<typename> class NumericSplitType = BestBinaryNumericSplit,
522  template<typename> class CategoricalSplitType = AllCategoricalSplit,
523  typename DimensionSelectType = AllDimensionSelect,
524  typename ElemType = double>
525 using DecisionStump = DecisionTree<FitnessFunction,
526  NumericSplitType,
527  CategoricalSplitType,
528  DimensionSelectType,
529  ElemType,
530  false>;
531 
532 } // namespace tree
533 } // namespace mlpack
534 
535 // Include implementation.
536 #include "decision_tree_impl.hpp"
537 
538 #endif
size_t NumChildren() const
Get the number of children.
Auxiliary information for a dataset, including mappings to/from strings (or other types) and the data...
DecisionTree(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Construct the decision tree on the given data and labels, where the data can be both numeric and cate...
The BestBinaryNumericSplit is a splitting function for decision trees that will exhaustively search a...
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
.hpp
Definition: add_to_po.hpp:21
This class implements a generic decision tree learner.
size_t CalculateDirection(const VecType &point) const
Given a point and that this node is not a leaf, calculate the index of the child node this point woul...
The core includes that mlpack expects; standard C++ includes and Armadillo.
NumericSplitType< FitnessFunction > NumericSplit
Allow access to the numeric split type.
const DecisionTree & Child(const size_t i) const
Get the child of the given index.
The AllCategoricalSplit is a splitting function that will split categorical features into many childr...
CategoricalSplitType< FitnessFunction > CategoricalSplit
Allow access to the categorical split type.
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
DecisionTree & operator=(const DecisionTree &other)
Copy another tree.
DecisionTree & Child(const size_t i)
Modify the child of the given index (be careful!).
DimensionSelectionType DimensionSelection
Allow access to the dimension selection type.
size_t NumClasses() const
Get the number of classes in the tree.
This dimension selection policy allows any dimension to be selected for splitting.
The Gini gain, a measure of set purity usable as a fitness function (FitnessFunction) for decision tr...
Definition: gini_gain.hpp:27
size_t Classify(const VecType &point) const
Classify the given point, using the entire tree.
size_t SplitDimension() const
Get the split dimension (only meaningful if this is a non-leaf in a trained tree).
double Train(MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, const size_t maximumDepth=0, DimensionSelectionType dimensionSelector=DimensionSelectionType())
Train the decision tree on the given data.
~DecisionTree()
Clean up memory.