// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2020 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3 of the License, or (at your option) any
// later version.
//
// This program is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
// details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef SAMPLER_HH
#define SAMPLER_HH

21
#include "random.hh"
22
#include <boost/mpl/if.hpp>
23 24 25 26 27 28

namespace graph_tool
{
using namespace std;
using namespace boost;

// Discrete sampling via Vose's alias method.
//
// See http://www.keithschwarz.com/darts-dice-coins/ for a very clear
// explanation.

template <class Value, class KeepReference = mpl::true_>
35 36 37
class Sampler
{
public:
38 39
    Sampler(const vector<Value>& items,
            const vector<double>& probs)
40 41
        : _items(items), _probs(probs), _alias(items.size()),
          _S(0)
42
    {
43
        for (size_t i = 0; i < _probs.size(); ++i)
44
            _S += _probs[i];
45

46 47 48
        vector<size_t> small;
        vector<size_t> large;

49
        for (size_t i = 0; i < _probs.size(); ++i)
50
        {
51
            _probs[i] *= _probs.size() / _S;
52
            if (_probs[i] < 1)
53
                small.push_back(i);
54
            else
55
                large.push_back(i);
56 57
        }

58
        while (!(small.empty() || large.empty()))
59
        {
60 61 62 63
            size_t l = small.back();
            size_t g = large.back();
            small.pop_back();
            large.pop_back();
64 65 66 67

            _alias[l] = g;
            _probs[g] = (_probs[l] + _probs[g]) - 1;
            if (_probs[g] < 1)
68
                small.push_back(g);
69
            else
70
                large.push_back(g);
71
        }
72

73
        // fix numerical instability
74 75 76 77
        for (size_t i = 0; i < large.size(); ++i)
            _probs[large[i]] = 1;
        for (size_t i = 0; i < small.size(); ++i)
            _probs[small[i]] = 1;
78 79

        _sample = uniform_int_distribution<size_t>(0, _probs.size() - 1);
80 81
    }

82 83
    Sampler() {}

84 85
    template <class RNG>
    const Value& sample(RNG& rng)
86
    {
87
        size_t i = _sample(rng);
Tiago Peixoto's avatar
Tiago Peixoto committed
88
        bernoulli_distribution coin(_probs[i]);
89 90 91 92
        if (coin(rng))
            return _items[i];
        else
            return _items[_alias[i]];
93 94
    }

95
    size_t size() const { return _items.size(); }
96
    bool empty() const { return _S == 0; }
97 98
    bool has_n(size_t n) const { return (n == 0 || !empty()); }
    double prob_sum() const { return _S; }
99

100 101 102 103 104 105 106 107 108 109
    const Value& operator[](size_t i) const
    {
        return _items[i];
    }

    const auto& items() const
    {
        return _items;
    }

110
    auto begin() const
111 112 113 114
    {
        return _items.begin();
    }

115
    auto end() const
116 117 118 119
    {
        return _items.end();
    }

120
private:
121

122 123 124 125
    typedef typename mpl::if_<KeepReference,
                              const vector<Value>&,
                              vector<Value> >::type items_t;
    items_t _items;
126 127
    vector<double> _probs;
    vector<size_t> _alias;
128
    uniform_int_distribution<size_t> _sample;
129
    double _S;
130
};

// Uniform sampling from iterator ranges and containers.

// Return an iterator to an element of [begin, end) chosen uniformly at
// random. For an empty range, `end` is returned and no random number is
// drawn; without this guard, `N - 1` would wrap around and the iterator
// would be advanced far past the end of the range.
template <class Iter, class RNG>
auto uniform_sample_iter(Iter begin, const Iter& end, RNG& rng)
{
    auto N = std::distance(begin, end);
    if (N == 0)
        return begin; // begin == end here
    std::uniform_int_distribution<size_t> i_rand(0, static_cast<size_t>(N) - 1);
    std::advance(begin, i_rand(rng));
    return begin;
}

// Container overload: uniformly sample an iterator into `v` (v.end() if v is
// empty).
template <class Container, class RNG>
auto uniform_sample_iter(Container& v, RNG& rng)
{
    return uniform_sample_iter(v.begin(), v.end(), rng);
}

// Return (a reference to) an element of [begin, end) chosen uniformly at
// random. Precondition: the range must be non-empty.
template <class Iter, class RNG>
typename std::iterator_traits<Iter>::reference
uniform_sample(const Iter& begin, const Iter& end, RNG& rng)
{
    auto iter = uniform_sample_iter(begin, end, rng);
    return *iter;
}

// Return an element of `v` chosen uniformly at random. Precondition: `v` must
// be non-empty. decltype(auto) returns T& for ordinary containers, but by
// value when the iterator yields a temporary (e.g. std::vector<bool>'s proxy
// reference), avoiding a dangling rvalue reference.
template <class Container, class RNG>
decltype(auto) uniform_sample(Container& v, RNG& rng)
{
    return *uniform_sample_iter(v, rng);
}

// Return the target of an out-edge of `v` picked uniformly at random, i.e. a
// random out-neighbour of `v` in which each out-edge is equally likely (so a
// neighbour joined by several parallel edges is proportionally more likely).
// Precondition: `v` must have at least one out-edge.
template <class Graph, class RNG>
typename boost::graph_traits<Graph>::vertex_descriptor
random_out_neighbor(typename boost::graph_traits<Graph>::vertex_descriptor v,
                    const Graph& g,
                    RNG& rng)
{
    auto edge_range = out_edges(v, g);
    auto chosen = uniform_sample_iter(edge_range.first, edge_range.second, rng);
    return target(*chosen, g);
}

// Return the source of an in-edge of `v` picked uniformly at random, i.e. a
// random in-neighbour of `v` in which each in-edge is equally likely.
// The edge range is obtained from in_edge_iteratorS<Graph>::get_edges
// (defined elsewhere). Precondition: that range must be non-empty.
template <class Graph, class RNG>
typename boost::graph_traits<Graph>::vertex_descriptor
random_in_neighbor(typename boost::graph_traits<Graph>::vertex_descriptor v,
                   const Graph& g,
                   RNG& rng)
{
    auto edge_range = in_edge_iteratorS<Graph>::get_edges(v, g);
    auto chosen = uniform_sample_iter(edge_range.first, edge_range.second, rng);
    return source(*chosen, g);
}
182 183 184 185

} // namespace graph_tool

#endif // SAMPLER_HH