all_categorical_split.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_HPP
14 #define MLPACK_METHODS_DECISION_TREE_ALL_CATEGORICAL_SPLIT_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace tree {
20 
27 template<typename FitnessFunction>
29 {
30  public:
31  // No extra info needed for split.
32  template<typename ElemType>
33  class AuxiliarySplitInfo { };
34 
55  template<bool UseWeights, typename VecType, typename WeightVecType>
56  static double SplitIfBetter(
57  const double bestGain,
58  const VecType& data,
59  const size_t numCategories,
60  const arma::Row<size_t>& labels,
61  const size_t numClasses,
62  const WeightVecType& weights,
63  const size_t minimumLeafSize,
64  const double minimumGainSplit,
65  arma::Col<typename VecType::elem_type>& classProbabilities,
67 
74  template<typename ElemType>
75  static size_t NumChildren(const arma::Col<ElemType>& classProbabilities,
76  const AuxiliarySplitInfo<ElemType>& /* aux */);
77 
84  template<typename ElemType>
85  static size_t CalculateDirection(
86  const ElemType& point,
87  const arma::Col<ElemType>& classProbabilities,
88  const AuxiliarySplitInfo<ElemType>& /* aux */);
89 };
90 
91 } // namespace tree
92 } // namespace mlpack
93 
94 // Include implementation.
95 #include "all_categorical_split_impl.hpp"
96 
97 #endif
98 
static size_t NumChildren(const arma::Col< ElemType > &classProbabilities, const AuxiliarySplitInfo< ElemType > &)
Return the number of children in the split.
.hpp
Definition: add_to_po.hpp:21
The core includes that mlpack expects; standard C++ includes and Armadillo.
The AllCategoricalSplit is a splitting function that will split categorical features into many children: one child for each category.
static double SplitIfBetter(const double bestGain, const VecType &data, const size_t numCategories, const arma::Row< size_t > &labels, const size_t numClasses, const WeightVecType &weights, const size_t minimumLeafSize, const double minimumGainSplit, arma::Col< typename VecType::elem_type > &classProbabilities, AuxiliarySplitInfo< typename VecType::elem_type > &aux)
Check if we can split a node.
static size_t CalculateDirection(const ElemType &point, const arma::Col< ElemType > &classProbabilities, const AuxiliarySplitInfo< ElemType > &)
Calculate the direction a point should percolate to.