ns_model.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17 
23 #include <boost/variant.hpp>
24 #include "neighbor_search.hpp"
25 
26 namespace mlpack {
27 namespace neighbor {
28 
32 template<typename SortPolicy,
33  template<typename TreeMetricType,
34  typename TreeStatType,
35  typename TreeMatType> class TreeType>
36 using NSType = NeighborSearch<SortPolicy,
38  arma::mat,
39  TreeType,
41  NeighborSearchStat<SortPolicy>,
42  arma::mat>::template DualTreeTraverser>;
43 
48 class MonoSearchVisitor : public boost::static_visitor<void>
49 {
50  private:
52  const size_t k;
54  arma::Mat<size_t>& neighbors;
56  arma::mat& distances;
57 
58  public:
60  template<typename NSType>
61  void operator()(NSType* ns) const;
62 
64  MonoSearchVisitor(const size_t k,
65  arma::Mat<size_t>& neighbors,
66  arma::mat& distances) :
67  k(k),
68  neighbors(neighbors),
69  distances(distances)
70  {};
71 };
72 
79 template<typename SortPolicy>
80 class BiSearchVisitor : public boost::static_visitor<void>
81 {
82  private:
84  const arma::mat& querySet;
86  const size_t k;
88  arma::Mat<size_t>& neighbors;
90  arma::mat& distances;
92  const size_t leafSize;
94  const double tau;
96  const double rho;
97 
99  template<typename NSType>
100  void SearchLeaf(NSType* ns) const;
101 
102  public:
104  template<template<typename TreeMetricType,
105  typename TreeStatType,
106  typename TreeMatType> class TreeType>
108 
110  template<template<typename TreeMetricType,
111  typename TreeStatType,
112  typename TreeMatType> class TreeType>
113  void operator()(NSTypeT<TreeType>* ns) const;
114 
116  void operator()(NSTypeT<tree::KDTree>* ns) const;
117 
119  void operator()(NSTypeT<tree::BallTree>* ns) const;
120 
122  void operator()(SpillKNN* ns) const;
123 
125  void operator()(NSTypeT<tree::Octree>* ns) const;
126 
128  BiSearchVisitor(const arma::mat& querySet,
129  const size_t k,
130  arma::Mat<size_t>& neighbors,
131  arma::mat& distances,
132  const size_t leafSize,
133  const double tau,
134  const double rho);
135 };
136 
143 template<typename SortPolicy>
144 class TrainVisitor : public boost::static_visitor<void>
145 {
146  private:
148  arma::mat&& referenceSet;
150  size_t leafSize;
152  const double tau;
154  const double rho;
155 
157  template<typename NSType>
158  void TrainLeaf(NSType* ns) const;
159 
160  public:
162  template<template<typename TreeMetricType,
163  typename TreeStatType,
164  typename TreeMatType> class TreeType>
166 
168  template<template<typename TreeMetricType,
169  typename TreeStatType,
170  typename TreeMatType> class TreeType>
171  void operator()(NSTypeT<TreeType>* ns) const;
172 
174  void operator()(NSTypeT<tree::KDTree>* ns) const;
175 
177  void operator()(NSTypeT<tree::BallTree>* ns) const;
178 
180  void operator()(SpillKNN* ns) const;
181 
183  void operator()(NSTypeT<tree::Octree>* ns) const;
184 
187  TrainVisitor(arma::mat&& referenceSet,
188  const size_t leafSize,
189  const double tau,
190  const double rho);
191 };
192 
196 class SearchModeVisitor : public boost::static_visitor<NeighborSearchMode&>
197 {
198  public:
200  template<typename NSType>
201  NeighborSearchMode& operator()(NSType* ns) const;
202 };
203 
207 class EpsilonVisitor : public boost::static_visitor<double&>
208 {
209  public:
211  template<typename NSType>
212  double& operator()(NSType *ns) const;
213 };
214 
218 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
219 {
220  public:
222  template<typename NSType>
223  const arma::mat& operator()(NSType *ns) const;
224 };
225 
229 class DeleteVisitor : public boost::static_visitor<void>
230 {
231  public:
233  template<typename NSType>
234  void operator()(NSType *ns) const;
235 };
236 
247 template<typename SortPolicy>
248 class NSModel
249 {
250  public:
253  {
268  OCTREE
269  };
270 
271  private:
273  TreeTypes treeType;
274 
276  size_t leafSize;
277 
279  double tau;
281  double rho;
282 
284  bool randomBasis;
286  arma::mat q;
287 
293  boost::variant<NSType<SortPolicy, tree::KDTree>*,
305  SpillKNN*,
308 
309  public:
318  NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
319 
325  NSModel(const NSModel& other);
326 
332  NSModel(NSModel&& other);
333 
339  NSModel& operator=(const NSModel& other);
340 
346  NSModel& operator=(NSModel&& other);
347 
349  ~NSModel();
350 
352  template<typename Archive>
353  void serialize(Archive& ar, const unsigned int /* version */);
354 
356  const arma::mat& Dataset() const;
357 
359  NeighborSearchMode SearchMode() const;
360  NeighborSearchMode& SearchMode();
361 
363  double Epsilon() const;
364  double& Epsilon();
365 
367  size_t LeafSize() const { return leafSize; }
368  size_t& LeafSize() { return leafSize; }
369 
371  double Tau() const { return tau; }
372  double& Tau() { return tau; }
373 
375  double Rho() const { return rho; }
376  double& Rho() { return rho; }
377 
379  TreeTypes TreeType() const { return treeType; }
380  TreeTypes& TreeType() { return treeType; }
381 
383  bool RandomBasis() const { return randomBasis; }
384  bool& RandomBasis() { return randomBasis; }
385 
387  void BuildModel(arma::mat&& referenceSet,
388  const size_t leafSize,
389  const NeighborSearchMode searchMode,
390  const double epsilon = 0);
391 
393  void Search(arma::mat&& querySet,
394  const size_t k,
395  arma::Mat<size_t>& neighbors,
396  arma::mat& distances);
397 
399  void Search(const size_t k,
400  arma::Mat<size_t>& neighbors,
401  arma::mat& distances);
402 
404  std::string TreeName() const;
405 };
406 
407 } // namespace neighbor
408 } // namespace mlpack
409 
411 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
413 
414 // Include implementation.
415 #include "ns_model_impl.hpp"
416 
417 #endif
MonoSearchVisitor(const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Construct the MonoSearchVisitor object with the given parameters.
Definition: ns_model.hpp:64
EpsilonVisitor exposes the Epsilon method of the given NSType.
Definition: ns_model.hpp:207
.hpp
Definition: add_to_po.hpp:21
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:383
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:252
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::NSModel< SortPolicy >, 1)
Set the serialization version of the NSModel class.
ReferenceSetVisitor exposes the referenceSet of the given NSType.
Definition: ns_model.hpp:218
SearchModeVisitor exposes the SearchMode() method of the given NSType.
Definition: ns_model.hpp:196
The NeighborSearch class is a template class for performing distance-based neighbor searches...
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:379
NeighborSearch< SortPolicy, metric::EuclideanDistance, arma::mat, TreeType, TreeType< metric::EuclideanDistance, NeighborSearchStat< SortPolicy >, arma::mat >::template DualTreeTraverser > NSType
Alias template for euclidean neighbor search.
Definition: ns_model.hpp:42
size_t LeafSize() const
Expose leafSize.
Definition: ns_model.hpp:367
BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
Definition: ns_model.hpp:80
TreeTypes & TreeType()
Definition: ns_model.hpp:380
The NSModel class provides an easy way to serialize a model and abstracts away the different types of trees it can be built on.
Definition: ns_model.hpp:248
double Tau() const
Expose tau.
Definition: ns_model.hpp:371
TrainVisitor sets the reference set to a new reference set on the given NSType.
void operator()(NSType *ns) const
Perform monochromatic nearest neighbor search.
MonoSearchVisitor executes a monochromatic neighbor search on the given NSType.
Definition: ns_model.hpp:48
DeleteVisitor deletes the given NSType instance.
Definition: ns_model.hpp:229
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
double Rho() const
Expose rho.
Definition: ns_model.hpp:375
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.