mlpack
a scalable c++ machine learning library
mlpack: src/mlpack/methods/emst/dtb.hpp Source File

dtb.hpp

Go to the documentation of this file.
00001 
00035 #ifndef __mlpack_METHODS_EMST_DTB_HPP
00036 #define __mlpack_METHODS_EMST_DTB_HPP
00037 
00038 #include "edge_pair.hpp"
00039 
00040 #include <mlpack/core.hpp>
00041 #include <mlpack/core/metrics/lmetric.hpp>
00042 
00043 #include <mlpack/core/tree/binary_space_tree.hpp>
00044 
00045 namespace mlpack {
00046 namespace emst  {
00047 
/**
 * Statistic attached to every node of the tree used by DualTreeBoruvka.
 * It holds two pieces of per-node state that the algorithm reads and
 * updates through the accessors below.  Initial values are assigned by the
 * out-of-line constructors (not visible in this header).
 */
class DTBStat
{
 private:
  /**
   * Per-node neighbor-distance value used during the search.  The name
   * suggests an upper bound on the distance from points in this node to
   * their candidate neighbor in another component -- TODO confirm against
   * dtb_impl.hpp.
   */
  double maxNeighborDistance;

  /**
   * Component index associated with this node.  Presumably the component
   * shared by all descendant points, or a sentinel when they differ --
   * verify against dtb_impl.hpp.
   */
  int componentMembership;

 public:
  /**
   * Default constructor.  Member values are set in the out-of-line
   * definition.
   */
  DTBStat();

  /**
   * Construct the statistic for the given tree node.
   *
   * This constructor is explicit.  A non-explicit single-argument template
   * constructor would make *every* type implicitly convertible to DTBStat,
   * so an argument of the wrong type could silently pick an overload that
   * takes a DTBStat.  Direct initialization (e.g. DTBStat(node)) is
   * unaffected.
   *
   * @tparam TreeType Type of the tree node this statistic belongs to.
   * @param node Node for which the statistic is being built.
   */
  template<typename TreeType>
  explicit DTBStat(const TreeType& node);

  //! Get the neighbor-distance value of this node.
  double MaxNeighborDistance() const { return maxNeighborDistance; }
  //! Modify the neighbor-distance value of this node.
  double& MaxNeighborDistance() { return maxNeighborDistance; }

  //! Get the component membership of this node.
  int ComponentMembership() const { return componentMembership; }
  //! Modify the component membership of this node.
  int& ComponentMembership() { return componentMembership; }

}; // class DTBStat
00092 
00131 template<
00132   typename MetricType = metric::SquaredEuclideanDistance,
00133   typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat>
00134 >
00135 class DualTreeBoruvka
00136 {
00137  private:
00139   typename TreeType::Mat dataCopy;
00141   typename TreeType::Mat& data;
00142 
00144   TreeType* tree;
00146   bool ownTree;
00147 
00149   bool naive;
00150 
00152   std::vector<EdgePair> edges; // We must use vector with non-numerical types.
00153 
00155   UnionFind connections;
00156 
00158   std::vector<size_t> oldFromNew;
00160   arma::Col<size_t> neighborsInComponent;
00162   arma::Col<size_t> neighborsOutComponent;
00164   arma::vec neighborsDistances;
00165 
00167   double totalDist;
00168 
00170   MetricType metric;
00171 
00172   // For sorting the edge list after the computation.
00173   struct SortEdgesHelper
00174   {
00175     bool operator()(const EdgePair& pairA, const EdgePair& pairB)
00176     {
00177       return (pairA.Distance() < pairB.Distance());
00178     }
00179   } SortFun;
00180 
00181  public:
00190   DualTreeBoruvka(const typename TreeType::Mat& dataset,
00191                   const bool naive = false,
00192                   const size_t leafSize = 1,
00193                   const MetricType metric = MetricType());
00194 
00212   DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset,
00213                   const MetricType metric = MetricType());
00214 
00218   ~DualTreeBoruvka();
00219 
00229   void ComputeMST(arma::mat& results);
00230 
00231  private:
00235   void AddEdge(const size_t e1, const size_t e2, const double distance);
00236 
00240   void AddAllEdges();
00241 
00245   void EmitResults(arma::mat& results);
00246 
00251   void CleanupHelper(TreeType* tree);
00252 
00256   void Cleanup();
00257 
00258 }; // class DualTreeBoruvka
00259 
00260 }; // namespace emst
00261 }; // namespace mlpack
00262 
00263 #include "dtb_impl.hpp"
00264 
00265 #endif // __mlpack_METHODS_EMST_DTB_HPP