sampler.hh 3.09 KB
Newer Older
1
2
// graph-tool -- a general graph modification and manipulation thingy
//
Tiago Peixoto's avatar
Tiago Peixoto committed
3
// Copyright (C) 2006-2013 Tiago de Paula Peixoto <tiago@skewed.de>
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef SAMPLER_HH
#define SAMPLER_HH

21
#include "random.hh"
#include <cassert>
#include <functional>
23
24
25
26
27
28

namespace graph_tool
{
using namespace std;
using namespace boost;

29
30
31
32
33
34
// Discrete sampling via Vose's alias method: linear-time table
// construction, constant-time draws.

// See http://www.keithschwarz.com/darts-dice-coins/ for a very clear
// explanation.

template <class Value>
class Sampler
{
    // Alias table layout: a column i is chosen uniformly at random; it
    // yields _items[i] with probability _probs[i], and _items[_alias[i]]
    // otherwise.
public:
    // Builds the alias table so that sample() returns items[i] with
    // probability probs[i] / sum(probs).
    //
    // items -- values handed out by sample(). Only a reference is stored,
    //          so this vector must outlive the Sampler and must keep at
    //          least probs.size() elements while the Sampler is used.
    // probs -- one weight per item, expected to be non-negative. The
    //          weights need not sum to one; they are normalized here.
    //
    // NOTE(review): with all-zero weights, S == 0 and every scaled
    // weight becomes NaN. NaN is never "< 1", so every column lands in
    // the over-full list and is reset to 1, and sampling degrades to
    // uniform.
    Sampler(const vector<Value>& items,
            const vector<double>& probs)
        : _items(items), _probs(probs), _alias(probs.size())
    {
        // sample() indexes _items with indices drawn from [0, probs.size()),
        // so items must be at least that long (equal lengths are intended).
        assert(probs.size() <= items.size());

        double S = 0;
        for (size_t i = 0; i < _probs.size(); ++i)
            S += _probs[i];

        // Rescale so that the average column weight is exactly 1.
        const double scale = _probs.size() / S;

        // Under-full (< 1) and full/over-full (>= 1) columns. These are
        // needed only while building the table, so they are locals; as
        // members they would keep their capacity for the object's lifetime.
        vector<size_t> small;
        vector<size_t> large;
        for (size_t i = 0; i < _probs.size(); ++i)
        {
            _probs[i] *= scale;
            if (_probs[i] < 1)
                small.push_back(i);
            else
                large.push_back(i);
        }

        // Pair each under-full column l with an over-full donor g: g fills
        // the rest of slot l (becoming its alias) and loses that excess.
        while (!small.empty() && !large.empty())
        {
            size_t l = small.back();
            size_t g = large.back();
            small.pop_back();
            large.pop_back();

            _alias[l] = g;
            _probs[g] = (_probs[l] + _probs[g]) - 1;
            if (_probs[g] < 1)
                small.push_back(g);
            else
                large.push_back(g);
        }

        // Any column still listed should hold weight exactly 1; deviations
        // are floating-point round-off, so clamp them to 1.
        for (size_t i = 0; i < large.size(); ++i)
            _probs[large[i]] = 1;
        for (size_t i = 0; i < small.size(); ++i)
            _probs[small[i]] = 1;
    }

    // Draws one item in constant time: pick a column uniformly, then keep
    // it with probability _probs[i] or fall back to its alias. Returns a
    // reference into the items vector given to the constructor. Requires
    // a non-empty table.
    template <class RNG>
    const Value& sample(RNG& rng) const
    {
        assert(!_probs.empty());

        // size_t (not the default int) so that large tables are indexed
        // correctly and size() - 1 is never narrowed.
        tr1::uniform_int<size_t> pick(0, _probs.size() - 1);
        size_t i = pick(rng);

        tr1::bernoulli_distribution coin(_probs[i]);
        if (coin(rng))
            return _items[i];
        else
            return _items[_alias[i]];
    }

private:
    // Non-owning view of the caller's items; see the constructor contract.
    const vector<Value>& _items;
    // _probs[i]: probability of keeping column i rather than its alias.
    vector<double> _probs;
    // _alias[i]: column used when column i is rejected.
    vector<size_t> _alias;
};
112
113
114
115
116
117



} // namespace graph_tool

#endif // SAMPLER_HH