sampler.hh 3.31 KB
Newer Older
1
2
// graph-tool -- a general graph modification and manipulation thingy
//
Tiago Peixoto's avatar
Tiago Peixoto committed
3
// Copyright (C) 2006-2013 Tiago de Paula Peixoto <tiago@skewed.de>
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef SAMPLER_HH
#define SAMPLER_HH

21
#include "random.hh"
22
#include <functional>
23
#include <boost/mpl/if.hpp>
24
25
26
27
28
29

namespace graph_tool
{
using namespace std;
using namespace boost;

30
31
32
33
34
// Discrete sampling via Vose's alias method.
//
// See http://www.keithschwarz.com/darts-dice-coins/ for a very clear
// explanation.

35
template <class Value, class KeepReference = mpl::true_>
36
37
38
class Sampler
{
public:
39
40
41
    Sampler(const vector<Value>& items,
            const vector<double>& probs)
        : _items(items), _probs(probs), _alias(items.size())
42
    {
43
44
45
        double S = 0;
        for (size_t i = 0; i < _probs.size(); ++i)
            S += _probs[i];
46

47
        for (size_t i = 0; i < _probs.size(); ++i)
48
        {
49
50
51
52
53
            _probs[i] *= _probs.size() / S;
            if (_probs[i] < 1)
                _small.push_back(i);
            else
                _large.push_back(i);
54
55
        }

56
        while (!(_small.empty() || _large.empty()))
57
        {
58
59
60
61
62
63
64
65
66
67
68
            size_t l = _small.back();
            size_t g = _large.back();
            _small.pop_back();
            _large.pop_back();

            _alias[l] = g;
            _probs[g] = (_probs[l] + _probs[g]) - 1;
            if (_probs[g] < 1)
                _small.push_back(g);
            else
                _large.push_back(g);
69
        }
70

71
72
73
74
75
76
77
        // fix numerical instability
        for (size_t i = 0; i < _large.size(); ++i)
            _probs[_large[i]] = 1;
        for (size_t i = 0; i < _small.size(); ++i)
            _probs[_small[i]] = 1;
        _large.clear();
        _small.clear();
78
79
    }

80
81
    Sampler() {}

82
83
    template <class RNG>
    const Value& sample(RNG& rng)
84
    {
85
86
        tr1::uniform_int<> sample(0, _probs.size() - 1);
        size_t i = sample(rng);
87

88
89
90
91
92
        tr1::bernoulli_distribution coin(_probs[i]);
        if (coin(rng))
            return _items[i];
        else
            return _items[_alias[i]];
93
94
    }

95
private:
96

97
    struct _cmp : binary_function <size_t, size_t, bool>
98
    {
99
100
101
        _cmp(const vector<double>& prob):_prob(prob) {}
        const vector<double>& _prob;
        bool operator() (const size_t& x, const size_t& y) const
102
        {
103
104
105
            if (_prob[x] == _prob[y])
                return x < y;
            return _prob[x] < _prob[y];
106
        }
107
    };
108

109
110
111
112
    typedef typename mpl::if_<KeepReference,
                              const vector<Value>&,
                              vector<Value> >::type items_t;
    items_t _items;
113
114
115
116
117
    vector<double> _probs;
    vector<size_t> _alias;
    vector<size_t> _small;
    vector<size_t> _large;
};
118
119
120
121
122
123



} // namespace graph_tool

#endif // SAMPLER_HH