regression_interpolation.hpp
Go to the documentation of this file.
1 
12 #ifndef MLPACK_METHODS_CF_REGRESSION_INTERPOLATION_HPP
13 #define MLPACK_METHODS_CF_REGRESSION_INTERPOLATION_HPP
14 
15 #include <mlpack/prereqs.hpp>
16 
17 namespace mlpack {
18 namespace cf {
19 
57 {
58  public:
63 
69  RegressionInterpolation(const arma::sp_mat& cleanedData)
70  {
71  const size_t userNum = cleanedData.n_cols;
72  a.set_size(userNum, userNum);
73  b.set_size(userNum, userNum);
74  }
75 
93  template <typename VectorType,
94  typename DecompositionPolicy>
95  void GetWeights(VectorType&& weights,
96  const DecompositionPolicy& decomposition,
97  const size_t queryUser,
98  const arma::Col<size_t>& neighbors,
99  const arma::vec& /* similarities*/,
100  const arma::sp_mat& cleanedData)
101  {
102  if (weights.n_elem != neighbors.n_elem)
103  {
104  Log::Fatal << "The size of the first parameter (weights) should "
105  << "be set to the number of neighbors before calling GetWeights()."
106  << std::endl;
107  }
108 
109  const arma::mat& w = decomposition.W();
110  const arma::mat& h = decomposition.H();
111  const size_t itemNum = cleanedData.n_rows;
112  const size_t neighborNum = neighbors.size();
113 
114  // Coeffcients of the linear equations used to compute weights.
115  arma::mat coeff(neighborNum, neighborNum);
116  // Constant terms of the linear equations used to compute weights.
117  arma::vec constant(neighborNum);
118 
119  arma::vec userRating(cleanedData.col(queryUser));
120  const size_t support = arma::accu(userRating != 0);
121 
122  // If user has no rating at all, average interpolation is used.
123  if (support == 0)
124  {
125  weights.fill(1.0 / neighbors.n_elem);
126  return;
127  }
128 
129  for (size_t i = 0; i < neighborNum; ++i)
130  {
131  // Calculate coefficient.
132  arma::vec iPrediction;
133  for (size_t j = i; j < neighborNum; ++j)
134  {
135  if (a(neighbors(i), neighbors(j)) != 0)
136  {
137  // The coefficient has already been cached.
138  coeff(i, j) = a(neighbors(i), neighbors(j));
139  coeff(j, i) = coeff(i, j);
140  }
141  else
142  {
143  // Calculate the coefficient.
144  if (iPrediction.size() == 0)
145  // Avoid recalculation of iPrediction.
146  iPrediction = w * h.col(neighbors(i));
147  arma::vec jPrediction = w * h.col(neighbors(j));
148  coeff(i, j) = arma::dot(iPrediction, jPrediction) / itemNum;
149  if (coeff(i, j) == 0)
150  coeff(i, j) = std::numeric_limits<double>::min();
151  coeff(j, i) = coeff(i, j);
152  // Cache calcualted coefficient.
153  a(neighbors(i), neighbors(j)) = coeff(i, j);
154  a(neighbors(j), neighbors(i)) = coeff(i, j);
155  }
156  }
157 
158  // Calculate constant terms.
159  if (b(neighbors(i), queryUser) != 0)
160  // The constant term has already been cached.
161  constant(i) = b(neighbors(i), queryUser);
162  else
163  {
164  // Calcuate the constant term.
165  if (iPrediction.size() == 0)
166  // Avoid recalculation of iPrediction.
167  iPrediction = w * h.col(neighbors(i));
168  constant(i) = arma::dot(iPrediction, userRating) / support;
169  if (constant(i) == 0)
170  constant(i) = std::numeric_limits<double>::min();
171  // Cache calculated constant term.
172  b(neighbors(i), queryUser) = constant(i);
173  }
174  }
175  weights = arma::solve(coeff, constant);
176  }
177 
178  private:
180  arma::sp_mat a;
182  arma::sp_mat b;
183 };
184 
185 } // namespace cf
186 } // namespace mlpack
187 
188 #endif
void GetWeights(VectorType &&weights, const DecompositionPolicy &decomposition, const size_t queryUser, const arma::Col< size_t > &neighbors, const arma::vec &, const arma::sp_mat &cleanedData)
The regression-based interpolation problem can be solved by a linear system of equations.
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
RegressionInterpolation(const arma::sp_mat &cleanedData)
Use cleanedData to perform necessary preprocessing.
Implementation of regression-based interpolation method.
constexpr T const & min(T const &lhs, T const &rhs)
Definition: algorithm.hpp:69
static util::PrefixedOutStream Fatal
Definition: log.hpp:105