mlpack
fast, flexible C++ machine learning library
DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion > Class Template Reference

This class implements a generic decision tree learner.

Inheritance diagram for DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion >:

Public Types

typedef CategoricalSplitType< FitnessFunction > CategoricalSplit
 Allow access to the categorical split type.

typedef DimensionSelectionType DimensionSelection
 Allow access to the dimension selection type.

typedef NumericSplitType< FitnessFunction > NumericSplit
 Allow access to the numeric split type.
 

Public Member Functions

template<typename MatType, typename LabelsType>
 DecisionTree (MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
 Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.

template<typename MatType, typename LabelsType>
 DecisionTree (MatType data, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
 Construct the decision tree on the given data and labels, assuming that the data is all of the numeric type.

template<typename MatType, typename LabelsType, typename WeightsType>
 DecisionTree (MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
 Construct the decision tree on the given data and labels with weights, where the data can be both numeric and categorical.

template<typename MatType, typename LabelsType, typename WeightsType>
 DecisionTree (MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
 Construct the decision tree on the given data and labels with weights, assuming that the data is all of the numeric type.

 DecisionTree (const size_t numClasses=1)
 Construct a decision tree without training it.

 DecisionTree (const DecisionTree &other)
 Copy another tree.

 DecisionTree (DecisionTree &&other)
 Take ownership of another tree.

 ~DecisionTree ()
 Clean up memory.

 
template<typename VecType>
size_t CalculateDirection (const VecType &point) const
 Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.

const DecisionTree & Child (const size_t i) const
 Get the child of the given index.

DecisionTree & Child (const size_t i)
 Modify the child of the given index (be careful!).

template<typename VecType>
size_t Classify (const VecType &point) const
 Classify the given point, using the entire tree.

template<typename VecType>
void Classify (const VecType &point, size_t &prediction, arma::vec &probabilities) const
 Classify the given point and also return estimates of the probability for each class in the given vector.

template<typename MatType>
void Classify (const MatType &data, arma::Row< size_t > &predictions) const
 Classify the given points, using the entire tree.

template<typename MatType>
void Classify (const MatType &data, arma::Row< size_t > &predictions, arma::mat &probabilities) const
 Classify the given points and also return estimates of the probabilities for each class in the given matrix.

 
size_t NumChildren () const
 Get the number of children.

size_t NumClasses () const
 Get the number of classes in the tree.

DecisionTree & operator= (const DecisionTree &other)
 Copy another tree.

DecisionTree & operator= (DecisionTree &&other)
 Take ownership of another tree.

template<typename Archive>
void serialize (Archive &ar, const unsigned int)
 Serialize the tree.

size_t SplitDimension () const
 Get the split dimension (only meaningful if this is a non-leaf in a trained tree).

 
template<typename MatType, typename LabelsType>
double Train (MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
 Train the decision tree on the given data.

template<typename MatType, typename LabelsType>
double Train (MatType data, LabelsType labels, const size_t numClasses, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType())
 Train the decision tree on the given data, assuming that all dimensions are numeric.

template<typename MatType, typename LabelsType, typename WeightsType>
double Train (MatType data, const data::DatasetInfo &datasetInfo, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
 Train the decision tree on the given weighted data.

template<typename MatType, typename LabelsType, typename WeightsType>
double Train (MatType data, LabelsType labels, const size_t numClasses, WeightsType weights, const size_t minimumLeafSize=10, const double minimumGainSplit=1e-7, DimensionSelectionType dimensionSelector=DimensionSelectionType(), const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *=0)
 Train the decision tree on the given weighted data, assuming that all dimensions are numeric.

 

Detailed Description


template<typename FitnessFunction = GiniGain, template< typename > class NumericSplitType = BestBinaryNumericSplit, template< typename > class CategoricalSplitType = AllCategoricalSplit, typename DimensionSelectionType = AllDimensionSelect, typename ElemType = double, bool NoRecursion = false>
class mlpack::tree::DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion >

This class implements a generic decision tree learner.

Its behavior can be controlled via its template arguments.

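As a rough illustration of how template arguments can steer a learner's behavior, the sketch below uses a toy class rather than DecisionTree itself. ToyGini, ToyEntropy, and ToyLearner are hypothetical names introduced here; only the general pattern of passing a fitness policy as a template argument is taken from the class declaration above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical fitness policy: Gini impurity of a class-count histogram.
struct ToyGini
{
  static double Impurity(const std::vector<std::size_t>& counts)
  {
    double total = 0.0;
    for (const std::size_t c : counts)
      total += static_cast<double>(c);
    assert(total > 0.0);

    double impurity = 1.0;
    for (const std::size_t c : counts)
    {
      const double p = static_cast<double>(c) / total;
      impurity -= p * p;
    }
    return impurity;
  }
};

// Hypothetical fitness policy: Shannon entropy in bits.
struct ToyEntropy
{
  static double Impurity(const std::vector<std::size_t>& counts)
  {
    double total = 0.0;
    for (const std::size_t c : counts)
      total += static_cast<double>(c);
    assert(total > 0.0);

    double entropy = 0.0;
    for (const std::size_t c : counts)
    {
      if (c == 0)
        continue;  // An empty class contributes nothing.
      const double p = static_cast<double>(c) / total;
      entropy -= p * std::log2(p);
    }
    return entropy;
  }
};

// Toy learner whose impurity measure is chosen at compile time.
template<typename FitnessPolicy = ToyGini>
struct ToyLearner
{
  static double NodeImpurity(const std::vector<std::size_t>& counts)
  {
    return FitnessPolicy::Impurity(counts);
  }
};
```

ToyLearner<> and ToyLearner<ToyEntropy> share all code except the policy call, which is the same idea as swapping FitnessFunction in the declaration above.
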
The class inherits from the auxiliary split information, rather than storing it as a member, so that an empty auxiliary split information struct adds no extra size to the tree (the empty base optimization).

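The size argument can be checked with a small standalone sketch. The struct names below are hypothetical stand-ins, not mlpack types; the sketch only shows why inheriting from an empty struct is cheaper than holding it as a member.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a split type that needs no auxiliary state.
struct EmptyAuxInfo { };

// Storing the empty struct as a member costs at least one byte, which
// alignment padding typically rounds up further.
struct NodeWithMember
{
  EmptyAuxInfo aux;
  std::size_t splitDimension;
};

// Inheriting from it lets the compiler apply the empty base optimization.
struct NodeWithBase : public EmptyAuxInfo
{
  std::size_t splitDimension;
};

static_assert(sizeof(NodeWithBase) == sizeof(std::size_t),
              "the empty base should add no size");
static_assert(sizeof(NodeWithMember) > sizeof(std::size_t),
              "an empty member still occupies storage");

// Number of bytes saved per node by inheriting instead of storing a member.
constexpr std::size_t BytesSavedByInheritance()
{
  return sizeof(NodeWithMember) - sizeof(NodeWithBase);
}
```

Because a tree can hold many nodes, the per-node saving is multiplied across the whole structure.
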
Definition at line 39 of file decision_tree.hpp.

Member Typedef Documentation

◆ CategoricalSplit

typedef CategoricalSplitType<FitnessFunction> CategoricalSplit

Allow access to the categorical split type.

Definition at line 49 of file decision_tree.hpp.

◆ DimensionSelection

typedef DimensionSelectionType DimensionSelection

Allow access to the dimension selection type.

Definition at line 51 of file decision_tree.hpp.

◆ NumericSplit

typedef NumericSplitType<FitnessFunction> NumericSplit

Allow access to the numeric split type.

Definition at line 47 of file decision_tree.hpp.

Constructor & Destructor Documentation

◆ DecisionTree() [1/7]

DecisionTree ( MatType  data,
const data::DatasetInfo &  datasetInfo,
LabelsType  labels,
const size_t  numClasses,
const size_t  minimumLeafSize = 10,
const double  minimumGainSplit = 1e-7,
DimensionSelectionType  dimensionSelector = DimensionSelectionType() 
)

Construct the decision tree on the given data and labels, where the data can be both numeric and categorical.

Setting minimumLeafSize and minimumGainSplit too small may cause the tree to overfit, but setting them too large may cause it to underfit.

If data or labels are no longer needed after this call, pass them with std::move to avoid copies.

Parameters
data               Dataset to train on.
datasetInfo        Type information for each dimension of the dataset.
labels             Labels for each training point.
numClasses         Number of classes in the dataset.
minimumLeafSize    Minimum number of points in each leaf node.
minimumGainSplit   Minimum gain for the node to split.
dimensionSelector  Instantiated dimension selection policy.

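To make the role of minimumLeafSize and minimumGainSplit concrete, here is a minimal sketch of how a candidate binary split might be accepted or rejected. The function name AcceptSplit and its exact rule are assumptions made for illustration; this is not mlpack's implementation.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical check: accept a candidate binary split only if both children
// keep at least minimumLeafSize points and the gain exceeds minimumGainSplit.
inline bool AcceptSplit(const double gain,
                        const std::size_t leftCount,
                        const std::size_t rightCount,
                        const std::size_t minimumLeafSize = 10,
                        const double minimumGainSplit = 1e-7)
{
  // A child would be too small, so the node stays a leaf.
  if (leftCount < minimumLeafSize || rightCount < minimumLeafSize)
    return false;

  // Splits that barely improve the fitness are rejected.
  return gain > minimumGainSplit;
}
```

Under this rule, very small thresholds let almost every split through, which produces deep trees that can memorize the training set. Very large thresholds reject splits early and leave shallow trees that can underfit.
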
◆ DecisionTree() [2/7]

DecisionTree ( MatType  data,
LabelsType  labels,
const size_t  numClasses,
const size_t  minimumLeafSize = 10,
const double  minimumGainSplit = 1e-7,
DimensionSelectionType  dimensionSelector = DimensionSelectionType() 
)

Construct the decision tree on the given data and labels, assuming that the data is all of the numeric type.

Setting minimumLeafSize and minimumGainSplit too small may cause the tree to overfit, but setting them too large may cause it to underfit.

If data or labels are no longer needed after this call, pass them with std::move to avoid copies.

Parameters
data               Dataset to train on.
labels             Labels for each training point.
numClasses         Number of classes in the dataset.
minimumLeafSize    Minimum number of points in each leaf node.
minimumGainSplit   Minimum gain for the node to split.
dimensionSelector  Instantiated dimension selection policy.

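The advice about std::move follows from the by-value parameters. The sketch below demonstrates the effect with std::vector instead of an Armadillo matrix; ReusesCallerBuffer and MoveAvoidsCopy are hypothetical helpers, and the same reasoning is assumed to carry over to any movable matrix type.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for a function that takes its dataset by value.
// Returns true if the parameter reuses the caller's buffer (it was moved in)
// rather than owning a fresh copy.
inline bool ReusesCallerBuffer(std::vector<double> data,
                               const double* callerBuffer)
{
  return data.data() == callerBuffer;
}

// Exercises both call styles. Returns true if the plain call copied the
// buffer and the std::move call reused it.
inline bool MoveAvoidsCopy()
{
  std::vector<double> dataset(1000, 1.0);
  const double* buffer = dataset.data();

  // Passing an lvalue copy-constructs the parameter: a new allocation.
  const bool copied = !ReusesCallerBuffer(dataset, buffer);

  // Passing an rvalue move-constructs the parameter: the buffer is handed
  // over, and dataset must not be relied on afterwards.
  const bool moved = ReusesCallerBuffer(std::move(dataset), buffer);

  return copied && moved;
}
```
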
◆ DecisionTree() [3/7]

DecisionTree ( MatType  data,
const data::DatasetInfo &  datasetInfo,
LabelsType  labels,
const size_t  numClasses,
WeightsType  weights,
const size_t  minimumLeafSize = 10,
const double  minimumGainSplit = 1e-7,
DimensionSelectionType  dimensionSelector = DimensionSelectionType(),
const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *  = 0 
)

Construct the decision tree on the given data and labels with weights, where the data can be both numeric and categorical.

Setting minimumLeafSize and minimumGainSplit too small may cause the tree to overfit, but setting them too large may cause it to underfit.

If data, labels or weights are no longer needed after this call, pass them with std::move to avoid copies.

Parameters
data               Dataset to train on.
datasetInfo        Type information for each dimension of the dataset.
labels             Labels for each training point.
numClasses         Number of classes in the dataset.
weights            Weight for each training point (one per label).
minimumLeafSize    Minimum number of points in each leaf node.
minimumGainSplit   Minimum gain for the node to split.
dimensionSelector  Instantiated dimension selection policy.

◆ DecisionTree() [4/7]

DecisionTree ( MatType  data,
LabelsType  labels,
const size_t  numClasses,
WeightsType  weights,
const size_t  minimumLeafSize = 10,
const double  minimumGainSplit = 1e-7,
DimensionSelectionType  dimensionSelector = DimensionSelectionType(),
const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *  = 0 
)

Construct the decision tree on the given data and labels with weights, assuming that the data is all of the numeric type.

Setting minimumLeafSize and minimumGainSplit too small may cause the tree to overfit, but setting them too large may cause it to underfit.

If data, labels or weights are no longer needed after this call, pass them with std::move to avoid copies.

Parameters
data               Dataset to train on.
labels             Labels for each training point.
numClasses         Number of classes in the dataset.
weights            Weight for each training point (one per label).
minimumLeafSize    Minimum number of points in each leaf node.
minimumGainSplit   Minimum gain for the node to split.
dimensionSelector  Instantiated dimension selection policy.

◆ DecisionTree() [5/7]

DecisionTree ( const size_t  numClasses = 1)

Construct a decision tree without training it.

It will be a leaf node with equal probabilities for each class.

Parameters
numClasses  Number of classes in the dataset.

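What "equal probabilities for each class" means for an untrained leaf can be written out directly. The helper below is a hypothetical sketch using std::vector, not the class's internal representation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: the class distribution an untrained leaf would report,
// with every class assigned probability 1 / numClasses.
inline std::vector<double> UntrainedLeafProbabilities(
    const std::size_t numClasses)
{
  assert(numClasses > 0);
  return std::vector<double>(numClasses,
                             1.0 / static_cast<double>(numClasses));
}
```

With the default numClasses = 1 the distribution is the single value 1.0, so an untrained tree would always predict class 0 under this sketch.
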
◆ DecisionTree() [6/7]

DecisionTree ( const DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion > &  other)

Copy another tree.

This may use a lot of memory, so be sure that it's what you want to do.

Parameters
other  Tree to copy.

◆ DecisionTree() [7/7]

DecisionTree ( DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion > &&  other)

Take ownership of another tree.

Parameters
other  Tree to take ownership of.

◆ ~DecisionTree()

Clean up memory.

Member Function Documentation

◆ CalculateDirection()

size_t CalculateDirection ( const VecType &  point) const

Given a point and that this node is not a leaf, calculate the index of the child node this point would go towards.

This method is primarily used by the Classify() function, but it can be used in a standalone sense too.

Parameters
point  Point to classify.

Referenced by DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion >::SplitDimension().

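For intuition, the direction computation for a binary numeric split can be sketched as follows. The function name, the use of std::vector for the point, and the convention that values less than or equal to the split value go to child 0 are all assumptions made for this example; the real behavior depends on the NumericSplitType and CategoricalSplitType in use.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical direction rule for a binary numeric split: child 0 receives
// points whose value in the split dimension is <= splitValue, child 1 the rest.
inline std::size_t DirectionForBinarySplit(const std::vector<double>& point,
                                           const std::size_t splitDimension,
                                           const double splitValue)
{
  assert(splitDimension < point.size());
  return (point[splitDimension] <= splitValue) ? 0 : 1;
}
```
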
◆ Child() [1/2]

const DecisionTree& Child ( const size_t  i) const
inline

Get the child of the given index.

Definition at line 386 of file decision_tree.hpp.

◆ Child() [2/2]

DecisionTree& Child ( const size_t  i)
inline

Modify the child of the given index (be careful!).

Definition at line 388 of file decision_tree.hpp.

◆ Classify() [1/4]

size_t Classify ( const VecType &  point) const

Classify the given point, using the entire tree.

The predicted label is returned.

Parameters
point  Point to classify.

◆ Classify() [2/4]

void Classify ( const VecType &  point,
size_t &  prediction,
arma::vec &  probabilities 
) const

Classify the given point and also return estimates of the probability for each class in the given vector.

Parameters
point          Point to classify.
prediction     This will be set to the predicted class of the point.
probabilities  This will be filled with class probabilities for the point.

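The following sketch shows one way classifying with the entire tree and reporting probabilities could work on a toy structure. ToyNode, ClassifySketch, and MakeToyStump are hypothetical, the split rule is an assumed binary numeric split, and std::vector stands in for arma::vec.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical toy node: a leaf holds class probabilities; an internal node
// holds a binary numeric split and exactly two children.
struct ToyNode
{
  std::vector<ToyNode> children;           // Empty for a leaf.
  std::vector<double> classProbabilities;  // Used only at leaves.
  std::size_t splitDimension = 0;
  double splitValue = 0.0;
};

// Descend from the root to a leaf, then report that leaf's distribution and
// its most probable class.
inline void ClassifySketch(const ToyNode& root,
                           const std::vector<double>& point,
                           std::size_t& prediction,
                           std::vector<double>& probabilities)
{
  const ToyNode* node = &root;
  while (!node->children.empty())
  {
    assert(node->children.size() == 2);
    assert(node->splitDimension < point.size());
    const std::size_t direction =
        (point[node->splitDimension] <= node->splitValue) ? 0 : 1;
    node = &node->children[direction];
  }

  assert(!node->classProbabilities.empty());
  probabilities = node->classProbabilities;
  prediction = static_cast<std::size_t>(std::distance(
      probabilities.begin(),
      std::max_element(probabilities.begin(), probabilities.end())));
}

// Builds a one-split tree: dimension 0 <= 0.5 goes left, otherwise right.
inline ToyNode MakeToyStump()
{
  ToyNode left;
  left.classProbabilities = {0.9, 0.1};
  ToyNode right;
  right.classProbabilities = {0.2, 0.8};

  ToyNode root;
  root.splitDimension = 0;
  root.splitValue = 0.5;
  root.children = {left, right};
  return root;
}
```
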
◆ Classify() [3/4]

void Classify ( const MatType &  data,
arma::Row< size_t > &  predictions 
) const

Classify the given points, using the entire tree.

The predicted labels for each point are stored in the given vector.

Parameters
data         Set of points to classify.
predictions  This will be filled with predictions for each point.

◆ Classify() [4/4]

void Classify ( const MatType &  data,
arma::Row< size_t > &  predictions,
arma::mat &  probabilities 
) const

Classify the given points and also return estimates of the probabilities for each class in the given matrix.

The predicted labels for each point are stored in the given vector.

Parameters
data           Set of points to classify.
predictions    This will be filled with predictions for each point.
probabilities  This will be filled with class probabilities for each point.

◆ NumChildren()

size_t NumChildren ( ) const
inline

Get the number of children.

Definition at line 383 of file decision_tree.hpp.

◆ NumClasses()

size_t NumClasses ( ) const

Get the number of classes in the tree.

◆ operator=() [1/2]

DecisionTree& operator= ( const DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion > &  other)

Copy another tree.

This may use a lot of memory, so be sure that it's what you want to do.

Parameters
other  Tree to copy.

◆ operator=() [2/2]

DecisionTree& operator= ( DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion > &&  other)

Take ownership of another tree.

Parameters
other  Tree to take ownership of.

◆ serialize()

void serialize ( Archive &  ar,
const unsigned int
)

Serialize the tree.

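◆ SplitDimension()

size_t SplitDimension ( ) const

Get the split dimension (only meaningful if this is a non-leaf in a trained tree).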

◆ Train() [1/4]

double Train ( MatType  data,
const data::DatasetInfo &  datasetInfo,
LabelsType  labels,
const size_t  numClasses,
const size_t  minimumLeafSize = 10,
const double  minimumGainSplit = 1e-7,
DimensionSelectionType  dimensionSelector = DimensionSelectionType() 
)

Train the decision tree on the given data.

This will overwrite the existing model. The data may have numeric and categorical types, specified by the datasetInfo parameter. Setting minimumLeafSize and minimumGainSplit too small may cause the tree to overfit, but setting them too large may cause it to underfit.

If data or labels are no longer needed after this call, pass them with std::move to avoid copies.

Parameters
data               Dataset to train on.
datasetInfo        Type information for each dimension.
labels             Labels for each training point.
numClasses         Number of classes in the dataset.
minimumLeafSize    Minimum number of points in each leaf node.
minimumGainSplit   Minimum gain for the node to split.
dimensionSelector  Instantiated dimension selection policy.
Returns
The final entropy of the decision tree.

Referenced by DecisionTree< FitnessFunction, NumericSplitType, CategoricalSplitType, DimensionSelectionType, ElemType, NoRecursion >::SplitDimension().

◆ Train() [2/4]

double Train ( MatType  data,
LabelsType  labels,
const size_t  numClasses,
const size_t  minimumLeafSize = 10,
const double  minimumGainSplit = 1e-7,
DimensionSelectionType  dimensionSelector = DimensionSelectionType() 
)

Train the decision tree on the given data, assuming that all dimensions are numeric.

This will overwrite the existing model. Setting minimumLeafSize and minimumGainSplit too small may cause the tree to overfit, but setting them too large may cause it to underfit.

If data or labels are no longer needed after this call, pass them with std::move to avoid copies.

Parameters
data               Dataset to train on.
labels             Labels for each training point.
numClasses         Number of classes in the dataset.
minimumLeafSize    Minimum number of points in each leaf node.
minimumGainSplit   Minimum gain for the node to split.
dimensionSelector  Instantiated dimension selection policy.
Returns
The final entropy of the decision tree.

◆ Train() [3/4]

double Train ( MatType  data,
const data::DatasetInfo &  datasetInfo,
LabelsType  labels,
const size_t  numClasses,
WeightsType  weights,
const size_t  minimumLeafSize = 10,
const double  minimumGainSplit = 1e-7,
DimensionSelectionType  dimensionSelector = DimensionSelectionType(),
const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *  = 0 
)

Train the decision tree on the given weighted data.

This will overwrite the existing model. The data may have numeric and categorical types, specified by the datasetInfo parameter. Setting minimumLeafSize and minimumGainSplit too small may cause the tree to overfit, but setting them too large may cause it to underfit.

If data, labels or weights are no longer needed after this call, pass them with std::move to avoid copies.

Parameters
data               Dataset to train on.
datasetInfo        Type information for each dimension.
labels             Labels for each training point.
numClasses         Number of classes in the dataset.
weights            Weight for each training point (one per label).
minimumLeafSize    Minimum number of points in each leaf node.
minimumGainSplit   Minimum gain for the node to split.
dimensionSelector  Instantiated dimension selection policy.
Returns
The final entropy of the decision tree.

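One way to picture what per-point weights change is to compare weighted and unweighted class proportions in a node. The helper below is a hypothetical sketch with std::vector inputs; it is not taken from mlpack, and how the fitness function actually consumes weights is not shown on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: class proportions in a node when each point counts
// with its weight instead of counting once.
inline std::vector<double> WeightedClassProbabilities(
    const std::vector<std::size_t>& labels,
    const std::vector<double>& weights,
    const std::size_t numClasses)
{
  assert(labels.size() == weights.size());

  std::vector<double> probabilities(numClasses, 0.0);
  double totalWeight = 0.0;
  for (std::size_t i = 0; i < labels.size(); ++i)
  {
    assert(labels[i] < numClasses);
    probabilities[labels[i]] += weights[i];
    totalWeight += weights[i];
  }

  assert(totalWeight > 0.0);
  for (double& p : probabilities)
    p /= totalWeight;
  return probabilities;
}
```

With labels {0, 0, 1} and weights {1, 1, 2}, the single class-1 point carries as much weight as both class-0 points together, so the two classes come out balanced even though the raw counts are 2 to 1.
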
◆ Train() [4/4]

double Train ( MatType  data,
LabelsType  labels,
const size_t  numClasses,
WeightsType  weights,
const size_t  minimumLeafSize = 10,
const double  minimumGainSplit = 1e-7,
DimensionSelectionType  dimensionSelector = DimensionSelectionType(),
const std::enable_if_t< arma::is_arma_type< typename std::remove_reference< WeightsType >::type >::value > *  = 0 
)

Train the decision tree on the given weighted data, assuming that all dimensions are numeric.

This will overwrite the existing model. Setting minimumLeafSize and minimumGainSplit too small may cause the tree to overfit, but setting them too large may cause it to underfit.

If data, labels or weights are no longer needed after this call, pass them with std::move to avoid copies.

Parameters
data               Dataset to train on.
labels             Labels for each training point.
numClasses         Number of classes in the dataset.
weights            Weight for each training point (one per label).
minimumLeafSize    Minimum number of points in each leaf node.
minimumGainSplit   Minimum gain for the node to split.
dimensionSelector  Instantiated dimension selection policy.
Returns
The final entropy of the decision tree.

The documentation for this class was generated from the following file:
  • /home/ryan/src/mlpack.org/_src/mlpack-git/src/mlpack/methods/decision_tree/decision_tree.hpp