mlpack
a scalable C++ machine learning library
docs
mlpack: src/mlpack/methods/emst/dtb.hpp Source File

dtb.hpp

Go to the documentation of this file.
00001 
00027 #ifndef __mlpack_METHODS_EMST_DTB_HPP
00028 #define __mlpack_METHODS_EMST_DTB_HPP
00029 
00030 #include "dtb_stat.hpp"
00031 #include "edge_pair.hpp"
00032 
00033 #include <mlpack/core.hpp>
00034 #include <mlpack/core/metrics/lmetric.hpp>
00035 
00036 #include <mlpack/core/tree/binary_space_tree.hpp>
00037 
00038 namespace mlpack {
00039 namespace emst  {
00040 
00079 template<
00080   typename MetricType = metric::EuclideanDistance,
00081   typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
00082 >
00083 class DualTreeBoruvka
00084 {
00085  private:
00087   typename TreeType::Mat dataCopy;
00089   const typename TreeType::Mat& data;
00090 
00092   TreeType* tree;
00094   bool ownTree;
00095 
00097   bool naive;
00098 
00100   std::vector<EdgePair> edges; // We must use vector with non-numerical types.
00101 
00103   UnionFind connections;
00104 
00106   std::vector<size_t> oldFromNew;
00108   arma::Col<size_t> neighborsInComponent;
00110   arma::Col<size_t> neighborsOutComponent;
00112   arma::vec neighborsDistances;
00113 
00115   double totalDist;
00116 
00118   MetricType metric;
00119 
00121   struct SortEdgesHelper
00122   {
00123     bool operator()(const EdgePair& pairA, const EdgePair& pairB)
00124     {
00125       return (pairA.Distance() < pairB.Distance());
00126     }
00127   } SortFun;
00128 
00129  public:
00138   DualTreeBoruvka(const typename TreeType::Mat& dataset,
00139                   const bool naive = false,
00140                   const MetricType metric = MetricType());
00141 
00159   DualTreeBoruvka(TreeType* tree,
00160                   const typename TreeType::Mat& dataset,
00161                   const MetricType metric = MetricType());
00162 
00166   ~DualTreeBoruvka();
00167 
00177   void ComputeMST(arma::mat& results);
00178 
00182   std::string ToString() const;
00183 
00184  private:
00188   void AddEdge(const size_t e1, const size_t e2, const double distance);
00189 
00193   void AddAllEdges();
00194 
00198   void EmitResults(arma::mat& results);
00199 
00204   void CleanupHelper(TreeType* tree);
00205 
00209   void Cleanup();
00210 
00211 }; // class DualTreeBoruvka
00212 
00213 }; // namespace emst
00214 }; // namespace mlpack
00215 
00216 #include "dtb_impl.hpp"
00217 
00218 #endif // __mlpack_METHODS_EMST_DTB_HPP