cover_tree.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
13 #define MLPACK_CORE_TREE_COVER_TREE_COVER_TREE_HPP
14 
15 #include <mlpack/prereqs.hpp>
17 
18 #include "../statistic.hpp"
19 #include "first_point_is_root.hpp"
20 
21 namespace mlpack {
22 namespace tree {
23 
95 template<typename MetricType = metric::LMetric<2, true>,
96  typename StatisticType = EmptyStatistic,
97  typename MatType = arma::mat,
98  typename RootPointPolicy = FirstPointIsRoot>
99 class CoverTree
100 {
101  public:
103  typedef MatType Mat;
105  typedef typename MatType::elem_type ElemType;
106 
117  CoverTree(const MatType& dataset,
118  const ElemType base = 2.0,
119  MetricType* metric = NULL);
120 
130  CoverTree(const MatType& dataset,
131  MetricType& metric,
132  const ElemType base = 2.0);
133 
141  CoverTree(MatType&& dataset,
142  const ElemType base = 2.0);
143 
152  CoverTree(MatType&& dataset,
153  MetricType& metric,
154  const ElemType base = 2.0);
155 
187  CoverTree(const MatType& dataset,
188  const ElemType base,
189  const size_t pointIndex,
190  const int scale,
191  CoverTree* parent,
192  const ElemType parentDistance,
193  arma::Col<size_t>& indices,
194  arma::vec& distances,
195  size_t nearSetSize,
196  size_t& farSetSize,
197  size_t& usedSetSize,
198  MetricType& metric = NULL);
199 
216  CoverTree(const MatType& dataset,
217  const ElemType base,
218  const size_t pointIndex,
219  const int scale,
220  CoverTree* parent,
221  const ElemType parentDistance,
222  const ElemType furthestDescendantDistance,
223  MetricType* metric = NULL);
224 
231  CoverTree(const CoverTree& other);
232 
239  CoverTree(CoverTree&& other);
240 
244  template<typename Archive>
245  CoverTree(
246  Archive& ar,
248 
252  ~CoverTree();
253 
256  template<typename RuleType>
258 
260  template<typename RuleType>
262 
263  template<typename RuleType>
265 
267  const MatType& Dataset() const { return *dataset; }
268 
270  size_t Point() const { return point; }
272  size_t Point(const size_t) const { return point; }
273 
274  bool IsLeaf() const { return (children.size() == 0); }
275  size_t NumPoints() const { return 1; }
276 
278  const CoverTree& Child(const size_t index) const { return *children[index]; }
280  CoverTree& Child(const size_t index) { return *children[index]; }
281 
282  CoverTree*& ChildPtr(const size_t index) { return children[index]; }
283 
285  size_t NumChildren() const { return children.size(); }
286 
288  const std::vector<CoverTree*>& Children() const { return children; }
290  std::vector<CoverTree*>& Children() { return children; }
291 
293  size_t NumDescendants() const;
294 
296  size_t Descendant(const size_t index) const;
297 
299  int Scale() const { return scale; }
301  int& Scale() { return scale; }
302 
304  ElemType Base() const { return base; }
306  ElemType& Base() { return base; }
307 
309  const StatisticType& Stat() const { return stat; }
311  StatisticType& Stat() { return stat; }
312 
317  template<typename VecType>
318  size_t GetNearestChild(
319  const VecType& point,
321 
326  template<typename VecType>
327  size_t GetFurthestChild(
328  const VecType& point,
330 
335  size_t GetNearestChild(const CoverTree& queryNode);
336 
341  size_t GetFurthestChild(const CoverTree& queryNode);
342 
344  ElemType MinDistance(const CoverTree& other) const;
345 
348  ElemType MinDistance(const CoverTree& other, const ElemType distance) const;
349 
351  ElemType MinDistance(const arma::vec& other) const;
352 
355  ElemType MinDistance(const arma::vec& other, const ElemType distance) const;
356 
358  ElemType MaxDistance(const CoverTree& other) const;
359 
362  ElemType MaxDistance(const CoverTree& other, const ElemType distance) const;
363 
365  ElemType MaxDistance(const arma::vec& other) const;
366 
369  ElemType MaxDistance(const arma::vec& other, const ElemType distance) const;
370 
373 
377  const ElemType distance) const;
378 
380  math::RangeType<ElemType> RangeDistance(const arma::vec& other) const;
381 
384  math::RangeType<ElemType> RangeDistance(const arma::vec& other,
385  const ElemType distance) const;
386 
388  CoverTree* Parent() const { return parent; }
390  CoverTree*& Parent() { return parent; }
391 
393  ElemType ParentDistance() const { return parentDistance; }
395  ElemType& ParentDistance() { return parentDistance; }
396 
398  ElemType FurthestPointDistance() const { return 0.0; }
399 
401  ElemType FurthestDescendantDistance() const
402  { return furthestDescendantDistance; }
405  ElemType& FurthestDescendantDistance() { return furthestDescendantDistance; }
406 
409  ElemType MinimumBoundDistance() const { return furthestDescendantDistance; }
410 
412  void Center(arma::vec& center) const
413  {
414  center = arma::vec(dataset->col(point));
415  }
416 
418  MetricType& Metric() const { return *metric; }
419 
420  private:
422  const MatType* dataset;
424  size_t point;
426  std::vector<CoverTree*> children;
428  int scale;
430  ElemType base;
432  StatisticType stat;
434  size_t numDescendants;
436  CoverTree* parent;
438  ElemType parentDistance;
440  ElemType furthestDescendantDistance;
442  bool localMetric;
444  bool localDataset;
446  MetricType* metric;
447 
451  void CreateChildren(arma::Col<size_t>& indices,
452  arma::vec& distances,
453  size_t nearSetSize,
454  size_t& farSetSize,
455  size_t& usedSetSize);
456 
468  void ComputeDistances(const size_t pointIndex,
469  const arma::Col<size_t>& indices,
470  arma::vec& distances,
471  const size_t pointSetSize);
486  size_t SplitNearFar(arma::Col<size_t>& indices,
487  arma::vec& distances,
488  const ElemType bound,
489  const size_t pointSetSize);
490 
510  size_t SortPointSet(arma::Col<size_t>& indices,
511  arma::vec& distances,
512  const size_t childFarSetSize,
513  const size_t childUsedSetSize,
514  const size_t farSetSize);
515 
516  void MoveToUsedSet(arma::Col<size_t>& indices,
517  arma::vec& distances,
518  size_t& nearSetSize,
519  size_t& farSetSize,
520  size_t& usedSetSize,
521  arma::Col<size_t>& childIndices,
522  const size_t childFarSetSize,
523  const size_t childUsedSetSize);
524  size_t PruneFarSet(arma::Col<size_t>& indices,
525  arma::vec& distances,
526  const ElemType bound,
527  const size_t nearSetSize,
528  const size_t pointSetSize);
529 
534  void RemoveNewImplicitNodes();
535 
536  protected:
543  CoverTree();
544 
546  friend class boost::serialization::access;
547 
548  public:
552  template<typename Archive>
553  void serialize(Archive& ar, const unsigned int /* version */);
554 
555  size_t DistanceComps() const { return distanceComps; }
556  size_t& DistanceComps() { return distanceComps; }
557 
558  private:
559  size_t distanceComps;
560 };
561 
562 } // namespace tree
563 } // namespace mlpack
564 
565 // Include implementation.
566 #include "cover_tree_impl.hpp"
567 
568 // Include the rest of the pieces, if necessary.
569 #include "../cover_tree.hpp"
570 
571 #endif
size_t GetNearestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the nearest child node to the given query point.
size_t DistanceComps() const
Definition: cover_tree.hpp:555
size_t NumPoints() const
Definition: cover_tree.hpp:275
MatType Mat
So that other classes can access the matrix type.
Definition: cover_tree.hpp:103
void Center(arma::vec &center) const
Get the center of the node and store it in the given vector.
Definition: cover_tree.hpp:412
A dual-tree cover tree traverser; see dual_tree_traverser.hpp.
Definition: cover_tree.hpp:261
typename enable_if< B, T >::type enable_if_t
Definition: prereqs.hpp:58
ElemType Base() const
Get the base.
Definition: cover_tree.hpp:304
.hpp
Definition: add_to_po.hpp:21
size_t Point() const
Get the index of the point which this node represents.
Definition: cover_tree.hpp:270
MatType::elem_type ElemType
The type held by the matrix type.
Definition: cover_tree.hpp:105
ElemType MaxDistance(const CoverTree &other) const
Return the maximum distance to another node.
The core includes that mlpack expects; standard C++ includes and Armadillo.
size_t GetFurthestChild(const VecType &point, typename std::enable_if_t< IsVector< VecType >::value > *=0)
Return the index of the furthest child node to the given query point.
ElemType & FurthestDescendantDistance()
Modify the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:405
const std::vector< CoverTree * > & Children() const
Get the children.
Definition: cover_tree.hpp:288
int & Scale()
Modify the scale of this node. Be careful...
Definition: cover_tree.hpp:301
StatisticType & Stat()
Modify the statistic for this node.
Definition: cover_tree.hpp:311
CoverTree()
A default constructor.
CoverTree *& Parent()
Modify the parent node.
Definition: cover_tree.hpp:390
CoverTree * Parent() const
Get the parent node.
Definition: cover_tree.hpp:388
std::vector< CoverTree * > & Children()
Modify the children manually (maybe not a great idea).
Definition: cover_tree.hpp:290
int Scale() const
Get the scale of this node.
Definition: cover_tree.hpp:299
~CoverTree()
Delete this cover tree node and its children.
const StatisticType & Stat() const
Get the statistic for this node.
Definition: cover_tree.hpp:309
void serialize(Archive &ar, const unsigned int)
Serialize the tree.
CoverTree *& ChildPtr(const size_t index)
Definition: cover_tree.hpp:282
A single-tree cover tree traverser; see single_tree_traverser.hpp for implementation.
Definition: cover_tree.hpp:257
ElemType ParentDistance() const
Get the distance to the parent.
Definition: cover_tree.hpp:393
ElemType MinDistance(const CoverTree &other) const
Return the minimum distance to another node.
size_t Point(const size_t) const
For compatibility with other trees; the argument is ignored.
Definition: cover_tree.hpp:272
const MatType & Dataset() const
Get a reference to the dataset.
Definition: cover_tree.hpp:267
size_t NumChildren() const
Get the number of children.
Definition: cover_tree.hpp:285
ElemType FurthestPointDistance() const
Get the distance to the furthest point. This is always 0 for cover trees.
Definition: cover_tree.hpp:398
ElemType & Base()
Modify the base; don't do this, you'll break everything.
Definition: cover_tree.hpp:306
Definition of the Range class, which represents a simple range with a lower and upper bound...
CoverTree & Child(const size_t index)
Modify a particular child node.
Definition: cover_tree.hpp:280
ElemType MinimumBoundDistance() const
Get the minimum distance from the center to any bound edge (this is the same as furthestDescendantDistance).
Definition: cover_tree.hpp:409
MetricType & Metric() const
Get the instantiated metric.
Definition: cover_tree.hpp:418
size_t Descendant(const size_t index) const
Get the index of a particular descendant point.
ElemType FurthestDescendantDistance() const
Get the distance from the center of the node to the furthest descendant.
Definition: cover_tree.hpp:401
math::RangeType< ElemType > RangeDistance(const CoverTree &other) const
Return the minimum and maximum distance to another node.
A cover tree is a tree specifically designed to speed up nearest-neighbor computation in high-dimensional spaces.
Definition: cover_tree.hpp:99
ElemType & ParentDistance()
Modify the distance to the parent.
Definition: cover_tree.hpp:395
const CoverTree & Child(const size_t index) const
Get a particular child node.
Definition: cover_tree.hpp:278
If value == true, then VecType is some sort of Armadillo vector or subview.
Definition: arma_traits.hpp:35
size_t NumDescendants() const
Get the number of descendant points.