lsh_search.hpp
Go to the documentation of this file.
1 
43 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
44 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
45 
46 #include <mlpack/prereqs.hpp>
47 
50 
51 #include <queue>
52 
53 namespace mlpack {
54 namespace neighbor {
55 
63 template<typename SortPolicy = NearestNeighborSort>
64 class LSHSearch
65 {
66  public:
89  LSHSearch(arma::mat referenceSet,
90  const arma::cube& projections,
91  const double hashWidth = 0.0,
92  const size_t secondHashSize = 99901,
93  const size_t bucketSize = 500);
94 
117  LSHSearch(arma::mat referenceSet,
118  const size_t numProj,
119  const size_t numTables,
120  const double hashWidth = 0.0,
121  const size_t secondHashSize = 99901,
122  const size_t bucketSize = 500);
123 
128  LSHSearch();
129 
135  LSHSearch(const LSHSearch& other);
136 
142  LSHSearch(LSHSearch&& other);
143 
149  LSHSearch& operator=(const LSHSearch& other);
150 
156  LSHSearch& operator=(LSHSearch&& other);
157 
183  void Train(arma::mat referenceSet,
184  const size_t numProj,
185  const size_t numTables,
186  const double hashWidth = 0.0,
187  const size_t secondHashSize = 99901,
188  const size_t bucketSize = 500,
189  const arma::cube& projection = arma::cube());
190 
212  void Search(const arma::mat& querySet,
213  const size_t k,
214  arma::Mat<size_t>& resultingNeighbors,
215  arma::mat& distances,
216  const size_t numTablesToSearch = 0,
217  const size_t T = 0);
218 
237  void Search(const size_t k,
238  arma::Mat<size_t>& resultingNeighbors,
239  arma::mat& distances,
240  const size_t numTablesToSearch = 0,
241  size_t T = 0);
242 
252  static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
253  const arma::Mat<size_t>& realNeighbors);
254 
260  template<typename Archive>
261  void serialize(Archive& ar, const unsigned int version);
262 
264  size_t DistanceEvaluations() const { return distanceEvaluations; }
266  size_t& DistanceEvaluations() { return distanceEvaluations; }
267 
269  const arma::mat& ReferenceSet() const { return referenceSet; }
270 
272  size_t NumProjections() const { return projections.n_slices; }
273 
275  const arma::mat& Offsets() const { return offsets; }
276 
278  const arma::vec& SecondHashWeights() const { return secondHashWeights; }
279 
281  size_t BucketSize() const { return bucketSize; }
282 
284  const std::vector<arma::Col<size_t>>& SecondHashTable() const
285  { return secondHashTable; }
286 
288  const arma::cube& Projections() { return projections; }
289 
291  void Projections(const arma::cube& projTables)
292  {
293  // Simply call Train() with the given projection tables.
294  Train(referenceSet, numProj, numTables, hashWidth, secondHashSize,
295  bucketSize, projTables);
296  }
297 
298  private:
314  template<typename VecType>
315  void ReturnIndicesFromTable(const VecType& queryPoint,
316  arma::uvec& referenceIndices,
317  size_t numTablesToSearch,
318  const size_t T) const;
319 
333  void BaseCase(const size_t queryIndex,
334  const arma::uvec& referenceIndices,
335  const size_t k,
336  arma::Mat<size_t>& neighbors,
337  arma::mat& distances) const;
338 
353  void BaseCase(const size_t queryIndex,
354  const arma::uvec& referenceIndices,
355  const size_t k,
356  const arma::mat& querySet,
357  arma::Mat<size_t>& neighbors,
358  arma::mat& distances) const;
359 
374  void GetAdditionalProbingBins(const arma::vec& queryCode,
375  const arma::vec& queryCodeNotFloored,
376  const size_t T,
377  arma::mat& additionalProbingBins) const;
378 
386  double PerturbationScore(const std::vector<bool>& A,
387  const arma::vec& scores) const;
388 
396  bool PerturbationShift(std::vector<bool>& A) const;
397 
406  bool PerturbationExpand(std::vector<bool>& A) const;
407 
415  bool PerturbationValid(const std::vector<bool>& A) const;
416 
418  arma::mat referenceSet;
419 
421  size_t numProj;
423  size_t numTables;
424 
426  arma::cube projections; // should be [numProj x dims] x numTables slices
427 
429  arma::mat offsets; // should be numProj x numTables
430 
432  double hashWidth;
433 
435  size_t secondHashSize;
436 
438  arma::vec secondHashWeights;
439 
441  size_t bucketSize;
442 
445  std::vector<arma::Col<size_t>> secondHashTable;
446 
449  arma::Col<size_t> bucketContentSize;
450 
453  arma::Col<size_t> bucketRowInHashTable;
454 
456  size_t distanceEvaluations;
457 
459  typedef std::pair<double, size_t> Candidate;
460 
462  struct CandidateCmp {
463  bool operator()(const Candidate& c1, const Candidate& c2)
464  {
465  return !SortPolicy::IsBetter(c2.first, c1.first);
466  };
467  };
468 
470  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
471  CandidateList;
472 }; // class LSHSearch
473 
474 } // namespace neighbor
475 } // namespace mlpack
476 
478 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
480 
481 // Include implementation.
482 #include "lsh_search_impl.hpp"
483 
484 #endif
void Search(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given matrices (resultingNeighbors and distances).
const std::vector< arma::Col< size_t > > & SecondHashTable() const
Get the second hash table.
Definition: lsh_search.hpp:284
const arma::mat & Offsets() const
Get the offsets 'b' for each of the projections. (One 'b' per column.)
Definition: lsh_search.hpp:275
.hpp
Definition: add_to_po.hpp:21
size_t DistanceEvaluations() const
Return the number of distance evaluations performed.
Definition: lsh_search.hpp:264
const arma::cube & Projections()
Get the projection tables.
Definition: lsh_search.hpp:288
The core includes that mlpack expects; standard C++ includes and Armadillo.
LSHSearch()
Create an untrained LSH model.
void serialize(Archive &ar, const unsigned int version)
Serialize the LSH model.
void Train(arma::mat referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the distance-approximate nearest-neighbors of the given queries.
Definition: lsh_search.hpp:64
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::LSHSearch< SortPolicy >, 1)
Set the serialization version of the LSHSearch class.
size_t NumProjections() const
Get the number of projections.
Definition: lsh_search.hpp:272
const arma::mat & ReferenceSet() const
Return the reference dataset.
Definition: lsh_search.hpp:269
size_t BucketSize() const
Get the bucket size of the second hash.
Definition: lsh_search.hpp:281
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "ground truth" set of neighbors.
void Projections(const arma::cube &projTables)
Change the projection tables (this retrains the LSH model).
Definition: lsh_search.hpp:291
size_t & DistanceEvaluations()
Modify the number of distance evaluations performed.
Definition: lsh_search.hpp:266
LSHSearch & operator=(const LSHSearch &other)
Copy the given LSH model.
const arma::vec & SecondHashWeights() const
Get the weights of the second hash.
Definition: lsh_search.hpp:278