/**
 * Include guard for ra_search_rules.hpp, followed by the class template
 * RASearchRules<SortPolicy, MetricType, TreeType>. This is a helper class used
 * by the RASearch class when performing rank-approximate search.
 *
 * NOTE(review): this text is a Doxygen listing extraction. The leading numbers
 * on several lines are fused source line numbers, not code. Source lines
 * 16-31 and 33-57 (the class head and the first constructor parameters) are
 * not present here.
 */
14 #ifndef MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP 15 #define MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP 32 template<
typename SortPolicy,
typename MetricType,
typename TreeType>
/**
 * Construct the RASearchRules object.
 *
 * NOTE(review): only part of the constructor's parameter list survives here
 * (source lines 58 and 62-67; 59-61 are elided). The full Doxygen signature
 * at the end of this listing gives the complete list:
 * (referenceSet, querySet, k, metric, tau = 5, alpha = 0.95, naive = false,
 *  sampleAtLeaves = false, firstLeafExact = false, singleSampleLimit = 20,
 *  sameSet = false).
 *
 * @param querySet Set of query points.
 * @param alpha Defaults to 0.95. Presumably the desired probability that the
 *     rank-approximation guarantee holds; confirm against the RASearch docs.
 * @param naive Flag, default false. Its semantics are not visible in this
 *     header.
 * @param sampleAtLeaves Flag, default false. Presumably enables sampling at
 *     leaf nodes; confirm.
 * @param firstLeafExact Flag, default false. Presumably makes the first leaf
 *     visited be searched exactly; confirm.
 * @param singleSampleLimit Defaults to 20. Presumably stored in the
 *     singleSampleLimit member.
 * @param sameSet Flag, default false. Presumably true when the query set and
 *     the reference set are the same dataset; confirm.
 */
58 const arma::mat& querySet,
62 const double alpha = 0.95,
63 const bool naive =
false,
64 const bool sampleAtLeaves =
false,
65 const bool firstLeafExact =
false,
66 const size_t singleSampleLimit = 20,
67 const bool sameSet =
false);
/**
 * Store the list of candidates for each query point in the given matrices.
 *
 * @param neighbors Output matrix that receives neighbor indices.
 * @param distances Output matrix that receives the corresponding distances.
 *
 * NOTE(review): the matrix layout (for example, one column per query point)
 * is not visible from this declaration. Confirm in ra_search_rules_impl.hpp.
 */
76 void GetResults(arma::Mat<size_t>& neighbors, arma::mat& distances);
/**
 * Get the distance from the query point to the reference point.
 *
 * @param queryIndex Index of the query point.
 * @param referenceIndex Index of the reference point.
 * @return The distance between the two points.
 */
85 double BaseCase(
const size_t queryIndex,
const size_t referenceIndex);
/**
 * Get the score for recursion order (single-tree form: one query point
 * against one reference node).
 *
 * @param queryIndex Index of the query point.
 * @param referenceNode Reference tree node under consideration.
 * @return Score used to decide the recursion order into referenceNode.
 */
109 double Score(
const size_t queryIndex, TreeType& referenceNode);
/**
 * Single-tree Score() overload that also receives baseCaseResult.
 *
 * NOTE(review): baseCaseResult is presumably the value returned by the
 * preceding BaseCase() call for this query/reference pair. Confirm in
 * ra_search_rules_impl.hpp.
 */
134 double Score(
const size_t queryIndex,
135 TreeType& referenceNode,
136 const double baseCaseResult);
/**
 * Re-evaluate the score for recursion order (single-tree form).
 *
 * @param queryIndex Index of the query point.
 * @param referenceNode Reference tree node being re-scored.
 * @param oldScore Previously computed score for this pair, presumably from
 *     Score(). Confirm.
 * @return The re-evaluated score.
 */
155 double Rescore(
const size_t queryIndex,
156 TreeType& referenceNode,
157 const double oldScore);
/**
 * Dual-tree form of Score(): scores a query node against a reference node.
 * The result determines recursion order.
 */
177 double Score(TreeType& queryNode, TreeType& referenceNode);
/**
 * Dual-tree Score() overload that also receives baseCaseResult.
 *
 * NOTE(review): presumably the result of the preceding BaseCase() call.
 * Confirm in ra_search_rules_impl.hpp.
 */
199 double Score(TreeType& queryNode,
200 TreeType& referenceNode,
201 const double baseCaseResult);
/**
 * Dual-tree form of Rescore(): re-evaluates oldScore for the
 * (queryNode, referenceNode) pair.
 */
225 double Rescore(TreeType& queryNode,
226 TreeType& referenceNode,
227 const double oldScore);
// Appears to be the body of NumEffectiveSamples(). The non-empty case returns
// the sum of all entries of numSamplesMade.
// NOTE(review): source lines 234-235 (the statement taken when numSamplesMade
// is empty, and the else) are missing from this extraction. Presumably the
// empty case returns 0; confirm.
233 if (numSamplesMade.n_elem == 0)
236 return arma::sum(numSamplesMade);
// Reference dataset. It is held by const reference, so the matrix must
// outlive this rules object.
246 const arma::mat& referenceSet;
// Query dataset. Also held by const reference, with the same lifetime
// requirement.
249 const arma::mat& querySet;
// A search candidate: (score/distance, point index). CandidateCmp orders
// candidates by the first element.
252 typedef std::pair<double, size_t> Candidate;
/**
 * Comparator for the candidate priority queue: cmp(c1, c2) is
 * !SortPolicy::IsBetter(c2.first, c1.first).
 *
 * std::priority_queue keeps the element that compares greatest on top.
 * Assuming IsBetter(a, b) means "a is at least as good as b", the top is the
 * worst candidate currently kept. TODO confirm against the SortPolicy in use.
 *
 * NOTE(review): this is a valid strict weak ordering only if IsBetter is
 * non-strict ("better or equal"). With a strict IsBetter, cmp(c, c) would be
 * true, which std::priority_queue does not permit.
 * NOTE(review): the operator's opening and closing braces are on elided
 * source lines (257, 259-262) and do not appear here.
 */
255 struct CandidateCmp {
256 bool operator()(
const Candidate& c1,
const Candidate& c2)
258 return !SortPolicy::IsBetter(c2.first, c1.first);
// Priority queue of candidates ordered by CandidateCmp. The alias name
// (CandidateList, used on the next line) is presumably on elided source
// line 264.
263 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
// Candidate lists. Per the GetResults() description, there is one list per
// query point.
267 std::vector<CandidateList> candidates;
// Presumably the constructor's singleSampleLimit argument (default 20).
282 size_t singleSampleLimit;
// Number of samples required. How it is derived is not visible here.
285 size_t numSamplesReqd;
// Sample counts. NumEffectiveSamples() sums this vector.
288 arma::Col<size_t> numSamplesMade;
// Sampling ratio. Its semantics are not visible in this header.
291 double samplingRatio;
// Count of distance computations, presumably exposed via
// NumDistComputations().
294 size_t numDistComputations;
// Traversal information (TraversalInfoType = tree::TraversalInfo<TreeType>)
// used in dual-tree and single-tree traversals.
299 TraversalInfoType traversalInfo;
/**
 * Record a candidate neighbor for a query point.
 *
 * @param queryIndex Index of the query point.
 * @param neighbor Index of the candidate reference point.
 * @param distance Distance between the query point and the candidate.
 *
 * NOTE(review): presumably this pushes (distance, neighbor) into
 * candidates[queryIndex] and keeps only the best candidates. The
 * implementation is not visible here; confirm.
 */
308 void InsertNeighbor(
const size_t queryIndex,
309 const size_t neighbor,
310 const double distance);
/**
 * Single-tree Score() helper that takes a precomputed distance and the
 * current bestDistance.
 *
 * NOTE(review): presumably shared by the public single-tree Score() and
 * Rescore() overloads. Confirm in ra_search_rules_impl.hpp.
 */
315 double Score(
const size_t queryIndex,
316 TreeType& referenceNode,
317 const double distance,
318 const double bestDistance);
/**
 * Dual-tree Score() helper that takes a precomputed distance and the
 * current bestDistance.
 *
 * NOTE(review): presumably shared by the public dual-tree overloads. Confirm.
 */
323 double Score(TreeType& queryNode,
324 TreeType& referenceNode,
325 const double distance,
326 const double bestDistance);
// NOTE(review): this is the tail of a statement whose opening (source line
// 328) is missing. The string looks like a static_assert message requiring
// TreeType to provide a unique number of descendant points; confirm. The
// condition is not visible here.
329 "must provide a unique number of descendants points.");
// Pulls in ra_search_rules_impl.hpp, which presumably defines the members
// declared above, and then closes the include guard.
// NOTE(review): everything after "#endif //" on the next line, including the
// RASearchRules(...) signature, is Doxygen tooltip residue. It landed inside
// the trailing // comment during extraction.
336 #include "ra_search_rules_impl.hpp" 338 #endif // MLPACK_METHODS_RANN_RA_SEARCH_RULES_HPP RASearchRules(const arma::mat &referenceSet, const arma::mat &querySet, const size_t k, MetricType &metric, const double tau=5, const double alpha=0.95, const bool naive=false, const bool sampleAtLeaves=false, const bool firstLeafExact=false, const size_t singleSampleLimit=20, const bool sameSet=false)
Construct the RASearchRules object.
The TraversalInfo class holds traversal information which is used in dual-tree (and single-tree) trav...
tree::TraversalInfo< TreeType > TraversalInfoType
const TraversalInfoType & TraversalInfo() const
double Rescore(const size_t queryIndex, TreeType &referenceNode, const double oldScore)
Re-evaluate the score for recursion order.
double Score(const size_t queryIndex, TreeType &referenceNode)
Get the score for recursion order.
size_t NumDistComputations()
size_t NumEffectiveSamples()
See subsection cli_alt_reg_tut (Alternate DET regularization): the usual regularized error \f$R_\alpha(t)\f$ of a node \f$t\f$ is given by
The TreeTraits class provides compile-time information on the characteristics of a given tree type...
TraversalInfoType & TraversalInfo()
double BaseCase(const size_t queryIndex, const size_t referenceIndex)
Get the distance from the query point to the reference point.
void GetResults(arma::Mat< size_t > &neighbors, arma::mat &distances)
Store the list of candidates for each query point in the given matrices.
The RASearchRules class is a template helper class used by RASearch class when performing rank-approx...