43 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP 44 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP 63 template<
typename SortPolicy = NearestNeighborSort>
90 const arma::cube& projections,
91 const double hashWidth = 0.0,
92 const size_t secondHashSize = 99901,
93 const size_t bucketSize = 500);
118 const size_t numProj,
119 const size_t numTables,
120 const double hashWidth = 0.0,
121 const size_t secondHashSize = 99901,
122 const size_t bucketSize = 500);
183 void Train(arma::mat referenceSet,
184 const size_t numProj,
185 const size_t numTables,
186 const double hashWidth = 0.0,
187 const size_t secondHashSize = 99901,
188 const size_t bucketSize = 500,
189 const arma::cube& projection = arma::cube());
212 void Search(
const arma::mat& querySet,
214 arma::Mat<size_t>& resultingNeighbors,
215 arma::mat& distances,
216 const size_t numTablesToSearch = 0,
237 void Search(
const size_t k,
238 arma::Mat<size_t>& resultingNeighbors,
239 arma::mat& distances,
240 const size_t numTablesToSearch = 0,
252 static double ComputeRecall(
const arma::Mat<size_t>& foundNeighbors,
253 const arma::Mat<size_t>& realNeighbors);
260 template<
typename Archive>
261 void serialize(Archive& ar,
const unsigned int version);
275 const arma::mat&
Offsets()
const {
return offsets; }
285 {
return secondHashTable; }
294 Train(referenceSet, numProj, numTables, hashWidth, secondHashSize,
295 bucketSize, projTables);
314 template<
typename VecType>
315 void ReturnIndicesFromTable(
const VecType& queryPoint,
316 arma::uvec& referenceIndices,
317 size_t numTablesToSearch,
318 const size_t T)
const;
333 void BaseCase(
const size_t queryIndex,
334 const arma::uvec& referenceIndices,
336 arma::Mat<size_t>& neighbors,
337 arma::mat& distances)
const;
353 void BaseCase(
const size_t queryIndex,
354 const arma::uvec& referenceIndices,
356 const arma::mat& querySet,
357 arma::Mat<size_t>& neighbors,
358 arma::mat& distances)
const;
374 void GetAdditionalProbingBins(
const arma::vec& queryCode,
375 const arma::vec& queryCodeNotFloored,
377 arma::mat& additionalProbingBins)
const;
386 double PerturbationScore(
const std::vector<bool>& A,
387 const arma::vec& scores)
const;
396 bool PerturbationShift(std::vector<bool>& A)
const;
406 bool PerturbationExpand(std::vector<bool>& A)
const;
415 bool PerturbationValid(
const std::vector<bool>& A)
const;
418 arma::mat referenceSet;
426 arma::cube projections;
435 size_t secondHashSize;
438 arma::vec secondHashWeights;
445 std::vector<arma::Col<size_t>> secondHashTable;
449 arma::Col<size_t> bucketContentSize;
453 arma::Col<size_t> bucketRowInHashTable;
456 size_t distanceEvaluations;
459 typedef std::pair<double, size_t> Candidate;
462 struct CandidateCmp {
463 bool operator()(
const Candidate& c1,
const Candidate& c2)
465 return !SortPolicy::IsBetter(c2.first, c1.first);
470 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
482 #include "lsh_search_impl.hpp" void Search(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given ...
const std::vector< arma::Col< size_t > > & SecondHashTable() const
Get the second hash table.
const arma::mat & Offsets() const
Get the offsets 'b' for each of the projections. (One 'b' per column.)
size_t DistanceEvaluations() const
Return the number of distance evaluations performed.
const arma::cube & Projections()
Get the projection tables.
The core includes that mlpack expects; standard C++ includes and Armadillo.
LSHSearch()
Create an untrained LSH model.
void serialize(Archive &ar, const unsigned int version)
Serialize the LSH model.
void Train(arma::mat referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the ...
BOOST_TEMPLATE_CLASS_VERSION(template< typename SortPolicy >, mlpack::neighbor::LSHSearch< SortPolicy >, 1)
Set the serialization version of the LSHSearch class.
size_t NumProjections() const
Get the number of projections.
const arma::mat & ReferenceSet() const
Return the reference dataset.
size_t BucketSize() const
Get the bucket size of the second hash.
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "gr...
void Projections(const arma::cube &projTables)
Change the projection tables (this retrains the LSH model).
size_t & DistanceEvaluations()
Modify the number of distance evaluations performed.
LSHSearch & operator=(const LSHSearch &other)
Copy the given LSH model.
const arma::vec & SecondHashWeights() const
Get the weights of the second hash.