// graph-tool -- a general graph modification and manipulation thingy // // Copyright (C) 2006-2013 Tiago de Paula Peixoto // // This program is free software; you can redistribute it and/or // modify it under the terms of the GNU General Public License // as published by the Free Software Foundation; either version 3 // of the License, or (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see . #ifndef SAMPLER_HH #define SAMPLER_HH #include "random.hh" #include #include namespace graph_tool { using namespace std; using namespace boost; // Discrete sampling via vose's alias method. // See http://www.keithschwarz.com/darts-dice-coins/ for a very clear // explanation, template class Sampler { public: Sampler(const vector& items, const vector& probs) : _items(items), _probs(probs), _alias(items.size()) { double S = 0; for (size_t i = 0; i < _probs.size(); ++i) S += _probs[i]; for (size_t i = 0; i < _probs.size(); ++i) { _probs[i] *= _probs.size() / S; if (_probs[i] < 1) _small.push_back(i); else _large.push_back(i); } while (!(_small.empty() || _large.empty())) { size_t l = _small.back(); size_t g = _large.back(); _small.pop_back(); _large.pop_back(); _alias[l] = g; _probs[g] = (_probs[l] + _probs[g]) - 1; if (_probs[g] < 1) _small.push_back(g); else _large.push_back(g); } // fix numerical instability for (size_t i = 0; i < _large.size(); ++i) _probs[_large[i]] = 1; for (size_t i = 0; i < _small.size(); ++i) _probs[_small[i]] = 1; _large.clear(); _small.clear(); } Sampler() {} template const Value& sample(RNG& rng) { tr1::uniform_int<> sample(0, _probs.size() - 1); size_t i = sample(rng); tr1::bernoulli_distribution coin(_probs[i]); if (coin(rng)) return _items[i]; else return _items[_alias[i]]; } private: struct _cmp : binary_function { _cmp(const vector& prob):_prob(prob) {} const vector& _prob; bool operator() (const size_t& x, const size_t& y) const { if (_prob[x] == _prob[y]) return x < y; return _prob[x] < _prob[y]; } }; typedef typename mpl::if_&, vector >::type items_t; items_t _items; vector _probs; vector _alias; vector _small; vector _large; }; } // namespace graph_tool #endif // SAMPLER_HH