similarity_interpolation.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_CF_SIMILARITY_INTERPOLATION_HPP
13 #define MLPACK_METHODS_CF_SIMILARITY_INTERPOLATION_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace cf {
19 
42 {
43  public:
44  // Empty constructor.
46 
50  SimilarityInterpolation(const arma::sp_mat& /* cleanedData */) { }
51 
65  template <typename VectorType,
66  typename DecompositionPolicy>
67  void GetWeights(VectorType&& weights,
68  const DecompositionPolicy& /* decomposition */,
69  const size_t /* queryUser */,
70  const arma::Col<size_t>& neighbors,
71  const arma::vec& similarities,
72  const arma::sp_mat& /* cleanedData */)
73  {
74  if (similarities.n_elem == 0)
75  {
76  Log::Fatal << "Require: similarities.n_elem > 0. There should be at "
77  << "least one neighbor!" << std::endl;
78  }
79 
80  if (weights.n_elem != neighbors.n_elem)
81  {
82  Log::Fatal << "The size of the first parameter (weights) should "
83  << "be set to the number of neighbors before calling GetWeights()."
84  << std::endl;
85  }
86 
87  double similaritiesSum = arma::sum(similarities);
88  if (std::fabs(similaritiesSum) < 1e-14)
89  {
90  weights.fill(1.0 / similarities.n_elem);
91  }
92  else
93  {
94  weights = similarities / similaritiesSum;
95  }
96  }
97 };
98 
99 } // namespace cf
100 } // namespace mlpack
101 
102 #endif
void GetWeights(VectorType &&weights, const DecompositionPolicy &, const size_t, const arma::Col< size_t > &neighbors, const arma::vec &similarities, const arma::sp_mat &)
Interpolation weights are computed as normalized similarities.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
SimilarityInterpolation(const arma::sp_mat &)
This constructor is needed for interface consistency.
With SimilarityInterpolation, interpolation weights are based on similarities between query user and ...
static util::PrefixedOutStream Fatal
Definition: log.hpp:105