// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2016 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef SAMPLER_HH
#define SAMPLER_HH

21
#include "random.hh"
22
#include <functional>
23
#include <boost/mpl/if.hpp>
24
25
26
27
28
29

namespace graph_tool
{
using namespace std;
using namespace boost;

30
31
32
// Discrete sampling via vose's alias method.

// See http://www.keithschwarz.com/darts-dice-coins/ for a very clear
33
// explanation.
34

35
template <class Value, class KeepReference = mpl::true_>
36
37
38
class Sampler
{
public:
39
40
41
    Sampler(const vector<Value>& items,
            const vector<double>& probs)
        : _items(items), _probs(probs), _alias(items.size())
42
    {
43
44
45
        double S = 0;
        for (size_t i = 0; i < _probs.size(); ++i)
            S += _probs[i];
46

47
        for (size_t i = 0; i < _probs.size(); ++i)
48
        {
49
50
51
52
53
            _probs[i] *= _probs.size() / S;
            if (_probs[i] < 1)
                _small.push_back(i);
            else
                _large.push_back(i);
54
55
        }

56
        while (!(_small.empty() || _large.empty()))
57
        {
58
59
60
61
62
63
64
65
66
67
68
            size_t l = _small.back();
            size_t g = _large.back();
            _small.pop_back();
            _large.pop_back();

            _alias[l] = g;
            _probs[g] = (_probs[l] + _probs[g]) - 1;
            if (_probs[g] < 1)
                _small.push_back(g);
            else
                _large.push_back(g);
69
        }
70

71
72
73
74
75
76
77
        // fix numerical instability
        for (size_t i = 0; i < _large.size(); ++i)
            _probs[_large[i]] = 1;
        for (size_t i = 0; i < _small.size(); ++i)
            _probs[_small[i]] = 1;
        _large.clear();
        _small.clear();
78
79

        _sample = uniform_int_distribution<size_t>(0, _probs.size() - 1);
80
81
    }

82
83
    Sampler() {}

84
85
    template <class RNG>
    const Value& sample(RNG& rng)
86
    {
87
        size_t i = _sample(rng);
Tiago Peixoto's avatar
Tiago Peixoto committed
88
        bernoulli_distribution coin(_probs[i]);
89
90
91
92
        if (coin(rng))
            return _items[i];
        else
            return _items[_alias[i]];
93
94
    }

95
    size_t size() const { return _items.size(); }
96
    bool empty() const { return _items.empty(); }
97

98
private:
99

100
101
102
103
    typedef typename mpl::if_<KeepReference,
                              const vector<Value>&,
                              vector<Value> >::type items_t;
    items_t _items;
104
105
106
107
    vector<double> _probs;
    vector<size_t> _alias;
    vector<size_t> _small;
    vector<size_t> _large;
108
109
    uniform_int_distribution<size_t> _sample;

110
};
111

112
113
114
// Uniform sampling from containers.
//
// Returns a reference to one element of `v`, chosen uniformly at random with
// `rng`. The container must provide size() and operator[]. It must not be
// empty: the upper bound of the index range is computed as v.size() - 1.
template <class Container, class RNG>
auto& uniform_sample(Container& v, RNG& rng)
{
    const size_t last = v.size() - 1;
    std::uniform_int_distribution<size_t> pick(0, last);
    const size_t pos = pick(rng);
    return v[pos];
}
120
121
122
123
124


} // namespace graph_tool

#endif // SAMPLER_HH