rp_tree_mean_split.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
14 #define MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "rp_tree_max_split.hpp"
20 
21 namespace mlpack {
22 namespace tree {
23 
32 template<typename BoundType, typename MatType = arma::mat>
34 {
35  public:
37  typedef typename MatType::elem_type ElemType;
39  struct SplitInfo
40  {
42  arma::Col<ElemType> direction;
44  arma::Col<ElemType> mean;
46  ElemType splitVal;
49  bool meanSplit;
50  };
51 
64  static bool SplitNode(const BoundType& /* bound */,
65  MatType& data,
66  const size_t begin,
67  const size_t count,
68  SplitInfo& splitInfo);
69 
83  static size_t PerformSplit(MatType& data,
84  const size_t begin,
85  const size_t count,
86  const SplitInfo& splitInfo)
87  {
88  return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
89  splitInfo);
90  }
91 
108  static size_t PerformSplit(MatType& data,
109  const size_t begin,
110  const size_t count,
111  const SplitInfo& splitInfo,
112  std::vector<size_t>& oldFromNew)
113  {
114  return split::PerformSplit<MatType, RPTreeMeanSplit>(data, begin, count,
115  splitInfo, oldFromNew);
116  }
117 
124  template<typename VecType>
125  static bool AssignToLeftNode(const VecType& point, const SplitInfo& splitInfo)
126  {
127  if (splitInfo.meanSplit)
128  return arma::dot(point - splitInfo.mean, point - splitInfo.mean) <=
129  splitInfo.splitVal;
130 
131  return (arma::dot(point, splitInfo.direction) <= splitInfo.splitVal);
132  }
133 
134  private:
141  static ElemType GetAveragePointDistance(MatType& data,
142  const arma::uvec& samples);
143 
153  static bool GetDotMedian(const MatType& data,
154  const arma::uvec& samples,
155  const arma::Col<ElemType>& direction,
156  ElemType& splitVal);
157 
167  static bool GetMeanMedian(const MatType& data,
168  const arma::uvec& samples,
169  arma::Col<ElemType>& mean,
170  ElemType& splitVal);
171 };
172 
173 } // namespace tree
174 } // namespace mlpack
175 
176 // Include implementation.
177 #include "rp_tree_mean_split_impl.hpp"
178 
179 #endif // MLPACK_CORE_TREE_BINARY_SPACE_TREE_RP_TREE_MEAN_SPLIT_HPP
.hpp
Definition: add_to_po.hpp:21
Information about the partition.
The core includes that mlpack expects; standard C++ includes and Armadillo.
ElemType splitVal
The value according to which the split will be performed.
static size_t PerformSplit(MatType &data, const size_t begin, const size_t count, const SplitInfo &splitInfo, std::vector< size_t > &oldFromNew)
Perform the split process according to the information about the split and return the list of changed indices through oldFromNew.
arma::Col< ElemType > mean
The mean of some sampled points.
This class splits a binary space tree.
static bool AssignToLeftNode(const VecType &point, const SplitInfo &splitInfo)
Returns true if the given point should be assigned to the left subtree.
static bool SplitNode(const BoundType &, MatType &data, const size_t begin, const size_t count, SplitInfo &splitInfo)
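Split the node either by a hyperplane (points whose projection onto direction is at most splitVal go left) or, when meanSplit is set, by squared distance from the sample mean (points within splitVal go left).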
MatType::elem_type ElemType
The element type held by the matrix type.
bool meanSplit
Indicates that we should use the mean split algorithm instead of the median split.
static size_t PerformSplit(MatType &data, const size_t begin, const size_t count, const SplitInfo &splitInfo)
Perform the split process according to the information about the split.
arma::Col< ElemType > direction
The normal to the hyperplane that will split the node.