// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2020 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef SAMPLER_HH
#define SAMPLER_HH

#include "random.hh"

#include <functional>
#include <iterator>
#include <vector>

#include <boost/mpl/if.hpp>

namespace graph_tool
{
using namespace std;
using namespace boost;

30
31
32
// Discrete sampling via vose's alias method.

// See http://www.keithschwarz.com/darts-dice-coins/ for a very clear
33
// explanation.
34

35
template <class Value, class KeepReference = mpl::true_>
36
37
38
class Sampler
{
public:
39
40
    Sampler(const vector<Value>& items,
            const vector<double>& probs)
41
42
        : _items(items), _probs(probs), _alias(items.size()),
          _S(0)
43
    {
44
        for (size_t i = 0; i < _probs.size(); ++i)
45
            _S += _probs[i];
46

47
48
49
        vector<size_t> small;
        vector<size_t> large;

50
        for (size_t i = 0; i < _probs.size(); ++i)
51
        {
52
            _probs[i] *= _probs.size() / _S;
53
            if (_probs[i] < 1)
54
                small.push_back(i);
55
            else
56
                large.push_back(i);
57
58
        }

59
        while (!(small.empty() || large.empty()))
60
        {
61
62
63
64
            size_t l = small.back();
            size_t g = large.back();
            small.pop_back();
            large.pop_back();
65
66
67
68

            _alias[l] = g;
            _probs[g] = (_probs[l] + _probs[g]) - 1;
            if (_probs[g] < 1)
69
                small.push_back(g);
70
            else
71
                large.push_back(g);
72
        }
73

74
        // fix numerical instability
75
76
77
78
        for (size_t i = 0; i < large.size(); ++i)
            _probs[large[i]] = 1;
        for (size_t i = 0; i < small.size(); ++i)
            _probs[small[i]] = 1;
79
80

        _sample = uniform_int_distribution<size_t>(0, _probs.size() - 1);
81
82
    }

83
84
    Sampler() {}

85
86
    template <class RNG>
    const Value& sample(RNG& rng)
87
    {
88
        size_t i = _sample(rng);
Tiago Peixoto's avatar
Tiago Peixoto committed
89
        bernoulli_distribution coin(_probs[i]);
90
91
92
93
        if (coin(rng))
            return _items[i];
        else
            return _items[_alias[i]];
94
95
    }

96
    size_t size() const { return _items.size(); }
97
    bool empty() const { return _S == 0; }
98
99
    bool has_n(size_t n) const { return (n == 0 || !empty()); }
    double prob_sum() const { return _S; }
100

101
102
103
104
105
106
107
108
109
110
    const Value& operator[](size_t i) const
    {
        return _items[i];
    }

    const auto& items() const
    {
        return _items;
    }

111
    auto begin() const
112
113
114
115
    {
        return _items.begin();
    }

116
    auto end() const
117
118
119
120
    {
        return _items.end();
    }

121
private:
122

123
124
125
126
    typedef typename mpl::if_<KeepReference,
                              const vector<Value>&,
                              vector<Value> >::type items_t;
    items_t _items;
127
128
    vector<double> _probs;
    vector<size_t> _alias;
129
    uniform_int_distribution<size_t> _sample;
130
    double _S;
131
    size_t _size;
132
};
133

// uniform sampling from containers

// Return an iterator to a uniformly chosen element of [begin, end).
//
// Iter must work with std::distance/std::advance; for iterators that are
// not random-access, the advance walks element by element.
// For an empty range, `begin` (which equals `end`) is returned without
// consuming any randomness; the caller must not dereference it.
template <class Iter, class RNG>
auto uniform_sample_iter(Iter begin, const Iter& end, RNG& rng)
{
    auto N = std::distance(begin, end);

    // Without this guard, N - 1 would wrap to SIZE_MAX when converted to
    // size_t, and std::advance would then run far past `end`.
    if (N <= 0)
        return begin;

    std::uniform_int_distribution<size_t> i_rand(0, static_cast<size_t>(N) - 1);
    std::advance(begin, static_cast<decltype(N)>(i_rand(rng)));
    return begin;
}

// Convenience overload: pick an iterator from the whole of container `v`
// by delegating to the iterator-pair overload.
template <class Container, class RNG>
auto uniform_sample_iter(Container& v, RNG& rng)
{
    auto first = v.begin();
    auto last = v.end();
    return uniform_sample_iter(first, last, rng);
}

// Return the element of [begin, end) at a uniformly chosen position,
// as the iterator's reference type. The range must be non-empty.
template <class Iter, class RNG>
typename std::iterator_traits<Iter>::reference
uniform_sample(const Iter& begin, const Iter& end, RNG& rng)
{
    return *uniform_sample_iter(begin, end, rng);
}

// Return a uniformly chosen element of container `v`, which must be
// non-empty.
//
// decltype(auto) forwards exactly what the iterator's operator* yields:
// an lvalue reference for ordinary containers, and a value when the
// iterator produces proxies or temporaries (e.g. std::vector<bool>).
// An `auto&&` return type would instead bind an rvalue reference to such a
// temporary, which dangles as soon as the function returns.
template <class Container, class RNG>
decltype(auto) uniform_sample(Container& v, RNG& rng)
{
    return *uniform_sample_iter(v, rng);
}

// Return the target of one of v's out-edges, chosen uniformly among those
// edges. With parallel edges, a neighbor is therefore picked in proportion
// to the number of edges leading to it. v must have at least one out-edge.
template <class Graph, class RNG>
typename boost::graph_traits<Graph>::vertex_descriptor
random_out_neighbor(typename boost::graph_traits<Graph>::vertex_descriptor v,
                    const Graph& g,
                    RNG& rng)
{
    auto es = out_edges(v, g);
    auto e = *uniform_sample_iter(es.first, es.second, rng);
    return target(e, g);
}

// Return the source of one of v's in-edges (as enumerated by
// in_edge_iteratorS<Graph>::get_edges), chosen uniformly among those edges.
// v must have at least one such edge.
template <class Graph, class RNG>
typename boost::graph_traits<Graph>::vertex_descriptor
random_in_neighbor(typename boost::graph_traits<Graph>::vertex_descriptor v,
                   const Graph& g,
                   RNG& rng)
{
    auto es = in_edge_iteratorS<Graph>::get_edges(v, g);
    auto e = *uniform_sample_iter(es.first, es.second, rng);
    return source(e, g);
}

} // namespace graph_tool

#endif // SAMPLER_HH