dtb.hpp
Go to the documentation of this file.
00001 00035 #ifndef __mlpack_METHODS_EMST_DTB_HPP 00036 #define __mlpack_METHODS_EMST_DTB_HPP 00037 00038 #include "edge_pair.hpp" 00039 00040 #include <mlpack/core.hpp> 00041 #include <mlpack/core/metrics/lmetric.hpp> 00042 00043 #include <mlpack/core/tree/binary_space_tree.hpp> 00044 00045 namespace mlpack { 00046 namespace emst { 00047 00052 class DTBStat 00053 { 00054 private: 00057 double maxNeighborDistance; 00062 int componentMembership; 00063 00064 public: 00069 DTBStat(); 00070 00078 template<typename TreeType> 00079 DTBStat(const TreeType& node); 00080 00082 double MaxNeighborDistance() const { return maxNeighborDistance; } 00084 double& MaxNeighborDistance() { return maxNeighborDistance; } 00085 00087 int ComponentMembership() const { return componentMembership; } 00089 int& ComponentMembership() { return componentMembership; } 00090 00091 }; // class DTBStat 00092 00131 template< 00132 typename MetricType = metric::SquaredEuclideanDistance, 00133 typename TreeType = tree::BinarySpaceTree<bound::HRectBound<2>, DTBStat> 00134 > 00135 class DualTreeBoruvka 00136 { 00137 private: 00139 typename TreeType::Mat dataCopy; 00141 typename TreeType::Mat& data; 00142 00144 TreeType* tree; 00146 bool ownTree; 00147 00149 bool naive; 00150 00152 std::vector<EdgePair> edges; // We must use vector with non-numerical types. 00153 00155 UnionFind connections; 00156 00158 std::vector<size_t> oldFromNew; 00160 arma::Col<size_t> neighborsInComponent; 00162 arma::Col<size_t> neighborsOutComponent; 00164 arma::vec neighborsDistances; 00165 00167 double totalDist; 00168 00170 MetricType metric; 00171 00172 // For sorting the edge list after the computation. 00173 struct SortEdgesHelper 00174 { 00175 bool operator()(const EdgePair& pairA, const EdgePair& pairB) 00176 { 00177 return (pairA.Distance() < pairB.Distance()); 00178 } 00179 } SortFun; 00180 00181 public: 00190 DualTreeBoruvka(const typename TreeType::Mat& dataset, 00191 const bool naive = false, 00192 const size_t leafSize = 1, 00193 const MetricType metric = MetricType()); 00194 00212 DualTreeBoruvka(TreeType* tree, const typename TreeType::Mat& dataset, 00213 const MetricType metric = MetricType()); 00214 00218 ~DualTreeBoruvka(); 00219 00229 void ComputeMST(arma::mat& results); 00230 00231 private: 00235 void AddEdge(const size_t e1, const size_t e2, const double distance); 00236 00240 void AddAllEdges(); 00241 00245 void EmitResults(arma::mat& results); 00246 00251 void CleanupHelper(TreeType* tree); 00252 00256 void Cleanup(); 00257 00258 }; // class DualTreeBoruvka 00259 00260 }; // namespace emst 00261 }; // namespace mlpack 00262 00263 #include "dtb_impl.hpp" 00264 00265 #endif // __mlpack_METHODS_EMST_DTB_HPP

1.7.1