[mlpack] [mlpack/mlpack] adds GammaDistribution::Train(observations, probabilities) (#834)

Ryan Curtin notifications at github.com
Wed Dec 21 16:48:32 EST 2016


rcurtin commented on this pull request.

The tests look good; I think they just need a few minor modifications.  Let me know what you think of my comments.  It would probably be a good idea to split these into multiple tests, too; if you prefer, I can do that after merge.

Also, there are a few places where the code doesn't follow the mlpack style guide... could you fix those please? :)  Specifically, `for(...)` -> `for (...)`, `i%2` -> `i % 2` (horizontal whitespace), and no variable names with underscores (`all_probabilities_1` -> `allProbabilities1`, etc.).  Thanks!
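
For reference, here is a minimal sketch of what those three style points look like together.  `AlternateValues` and its parameters are made up for illustration and are not from the PR; only the whitespace and naming conventions are the point.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of the PR): fills a vector by alternating
// between two values, written with the spacing and naming asked for above.
std::vector<double> AlternateValues(const std::size_t n,
                                    const double evenValue,
                                    const double oddValue)
{
  // camelCase name with no underscores (not all_values_1).
  std::vector<double> allValues1(n);

  // Space after `for`, and spaces around binary operators such as `<`.
  for (std::size_t i = 0; i < n; i++)
  {
    // Space after `if`, and `i % 2` rather than `i%2`.
    if (i % 2 == 0)
      allValues1[i] = evenValue;
    else
      allValues1[i] = oddValue;
  }

  return allValues1;
}
```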

Lastly, do you want to add your name and email to `src/mlpack/core.hpp` and `COPYRIGHT.txt`?

> +  for(size_t i = 0; i < N; i++)
> +    probabilities(i) = prob(generator);
> +
> +  // fit results with probabilities and data
> +  GammaDistribution gDist;
> +  gDist.Train(rdata, probabilities);
> +
> +  // fit results with only data
> +  GammaDistribution gDist2;
> +  gDist2.Train(rdata);
> +
> +  BOOST_REQUIRE_CLOSE(gDist2.Alpha(0), gDist.Alpha(0), 10);
> +  BOOST_REQUIRE_CLOSE(gDist2.Beta(0), gDist.Beta(0), 10);
> +
> +  BOOST_REQUIRE_CLOSE(alphaReal, gDist.Alpha(0), 10);
> +  BOOST_REQUIRE_CLOSE(betaReal, gDist.Beta(0), 10);

We can probably check these with much closer tolerances.  I'd suggest 1e-5 instead of 10.  Shouldn't we also check `gDist.Alpha(1)` and `gDist.Beta(1)`?
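
To make the difference between the two tolerances concrete: the third argument of `BOOST_REQUIRE_CLOSE` is a percentage, so `10` accepts values that differ by up to 10%.  The sketch below uses only the standard library; `CloseWithinPercent` is a hypothetical stand-in written for illustration, not Boost's implementation, and it assumes the check must hold relative to both operands.

```cpp
#include <cmath>

// Hypothetical stand-in for a percent-tolerance comparison (not Boost code).
// Returns true when |left - right| is within `percent` percent of BOTH
// |left| and |right|.
bool CloseWithinPercent(const double left,
                        const double right,
                        const double percent)
{
  const double diff = std::fabs(left - right);
  const double fraction = percent / 100.0; // Convert percent to a fraction.

  return diff <= fraction * std::fabs(left) &&
         diff <= fraction * std::fabs(right);
}
```

Under that reading, `CloseWithinPercent(2.0, 2.1, 10.0)` passes but `CloseWithinPercent(2.0, 2.1, 1e-5)` does not: a tolerance of 1e-5 percent only admits relative differences of about 1e-7.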

> +    for(size_t i = 0; i < N; i++)
> +    {
> +      if(i%2 == 0)
> +        rdata(j, i) = dist(generator);
> +      else
> +        rdata(j, i) = dist2(generator);
> +    }
> +  }
> +
> +  for(size_t i = 0; i<N; i++)
> +  {
> +    if(i%2 == 0)
> +      probabilities(i) = low_prob(generator);
> +    else
> +      probabilities(i) = high_prob(generator);
> +  }

Another possibility to simplify the creation here is just to put the first N/2 points in the first part of `rdata` from `low_prob` and then the second N/2 points in the second part of `rdata` from `high_prob`, then call `arma::shuffle()` to shuffle the dataset.
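
A rough sketch of that idea follows, with some assumptions: it uses a single dimension, swaps in `std::shuffle` on an index permutation in place of `arma::shuffle()` so the example has no Armadillo dependency, and the distribution parameters, `WeightedSample`, and `MakeShuffledSample` are all made up for illustration.  The one thing to watch is that each point and its probability must be permuted together.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical container for the generated points and their probabilities.
struct WeightedSample
{
  std::vector<double> data;
  std::vector<double> probabilities;
};

// Fill the first half from one distribution (with low probabilities) and the
// second half from the other (with high probabilities), then shuffle both
// with the same permutation so each point keeps its own probability.
WeightedSample MakeShuffledSample(const std::size_t n, std::mt19937& generator)
{
  // Assumed parameters; the PR's actual values are not shown in the review.
  std::gamma_distribution<double> dist(2.0, 1.5);
  std::gamma_distribution<double> dist2(5.0, 0.5);
  std::uniform_real_distribution<double> lowProb(0.0, 0.02);
  std::uniform_real_distribution<double> highProb(0.98, 1.0);

  std::vector<double> data(n), probabilities(n);
  for (std::size_t i = 0; i < n; i++)
  {
    const bool firstHalf = (i < n / 2);
    data[i] = firstHalf ? dist(generator) : dist2(generator);
    probabilities[i] = firstHalf ? lowProb(generator) : highProb(generator);
  }

  // Shuffle a list of indices once, then apply it to both vectors.
  std::vector<std::size_t> order(n);
  std::iota(order.begin(), order.end(), std::size_t(0));
  std::shuffle(order.begin(), order.end(), generator);

  WeightedSample sample;
  sample.data.resize(n);
  sample.probabilities.resize(n);
  for (std::size_t i = 0; i < n; i++)
  {
    sample.data[i] = data[order[i]];
    sample.probabilities[i] = probabilities[order[i]];
  }

  return sample;
}
```

With Armadillo, shuffling the data and the probabilities in two separate `arma::shuffle()` calls would presumably break that pairing, so some shared ordering would be needed there too.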

-- 
You are receiving this because you commented.
Reply to this email directly or view it on GitHub:
https://github.com/mlpack/mlpack/pull/834#pullrequestreview-14055423