// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2007-2012 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef SAMPLER_HH
#define SAMPLER_HH

21
22
23
24
25
#if (GCC_VERSION >= 40400)
#   include <tr1/random>
#else
#   include <boost/tr1/random.hpp>
#endif
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <iostream>

namespace graph_tool
{
// NOTE: these using-directives apply to everything that follows inside
// namespace graph_tool, in every file that includes this header. The classes
// below depend on them for their unqualified names (vector, make_pair, swap,
// tr1::..., hash).
using namespace std;
using namespace boost;

// Random number generator type taken by the samplers' operator(): the TR1
// Mersenne Twister engine.
typedef tr1::mt19937 rng_t;

// utility class to sample uniformly from a collection of values
template <class ValueType>
class Sampler
{
public:
40
    Sampler() {}
41
42

    template <class Iterator>
43
    Sampler(Iterator iter, Iterator end)
44
45
    {
        for(; iter != end; ++iter)
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
            Insert(*iter);
    }

    void Insert(const ValueType& v)
    {
        _candidates.push_back(v);
        _candidates_set.insert(make_pair(v, _candidates.size() - 1));
    }

    bool HasValue(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end;
        tie(iter, end) = _candidates_set.equal_range(v);
        return (iter != end);
    }

    void Remove(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, back;
        iter = _candidates_set.find(v);

        if (iter == _candidates_set.end())
            return;

        back = _candidates_set.find(_candidates.back());

        size_t index = iter->second;
        swap(_candidates[index], _candidates.back());
        _candidates.pop_back();

        if (!_candidates.empty() && back != iter)
77
        {
78
79
            _candidates_set.erase(back);
            _candidates_set.insert(make_pair(_candidates[index], index));
80
        }
81
        _candidates_set.erase(iter);
82
83
    }

84
    bool Empty()
85
    {
86
87
88
89
90
91
92
93
94
95
96
97
98
99
        return _candidates.empty();
    }

    size_t Size()
    {
        return _candidates.size();
    }

    ValueType operator()(rng_t& rng, bool remove = false)
    {
        //assert(!_candidates.empty());
        tr1::uniform_int<> sample(0, _candidates.size() - 1);
        int i = sample(rng);
        if (remove)
100
        {
101
102
103
104
105
106
107
108
            swap(_candidates[i], _candidates.back());
            ValueType ret = _candidates.back();
            _candidates.pop_back();
            return ret;
        }
        else
        {
            return _candidates[i];
109
110
111
        }
    }

112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
private:
    vector<ValueType> _candidates;
    tr1::unordered_multimap<ValueType, size_t, hash<ValueType> >
        _candidates_set;
};


template <class ValueType>
class WeightedSampler
{
public:

    void Insert(const ValueType& v, double p)
    {
        _candidates.push_back(make_pair(v, p));
        _candidates_set.insert(make_pair(v, _candidates.size() - 1));
        _erased.push_back(false);
        _rebuild = true;
    }

132
133
134
135
136
137
138
139
140
141
142
    bool HasValue(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end;
        tie(iter, end) = _candidates_set.equal_range(v);
        return (iter != end);
    }

    void Remove(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end, temp;
        tie(iter, end) = _candidates_set.equal_range(v);
143
144
145
     
        if (iter == end)
            return;
146

147
        while(_erased[iter->second])
148
        {
149
150
151
152
            temp = iter++;
            _candidates_set.erase(temp);
            if (iter == end)
                return;
153
154
        }

155
156
157
158
159
        size_t index = iter->second;
        _erased[index] = true;
        _erased_prob += _candidates[index].second;
        if (_erased_prob >= 0.3)
            _rebuild = true;
160
161
162
163
164
165
166
167
168
169
170
171
    }

    bool Empty()
    {
        return _candidates.empty();
    }

    size_t Size()
    {
        return _candidates.size();
    }

172
    void BuildTable()
173
    {
174
175
176
        // remove possibly erased elements
        size_t i = 0;
        while (i < _candidates.size())
177
        {
178
            if (_erased[i])
179
180
            {
                swap(_candidates[i], _candidates.back());
181
                swap(_erased[i], _erased.back());
182
                _candidates.pop_back();
183
                _erased.pop_back();
184
185
186
            }
            else
            {
187
                ++i;
188
189
            }
        }
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
        _erased_prob = 0;

        vector<pair<size_t, double> > remainder;
        _alias.resize(_candidates.size());

        double P_sum = 0;
        for (size_t i = 0; i < _candidates.size(); ++i)
            P_sum += _candidates[i].second;

        
        size_t N = _candidates.size();
        double P = 1.0 / N;
        for (size_t i = 0; i < _candidates.size(); ++i)
        {
            _candidates[i].second /= P_sum;
            double pi = _candidates[i].second;
            if (pi > P)
                remainder.push_back(make_pair(i, pi - P));
            _alias[i] = make_pair(i, .1);
        }


        for (size_t i = 0; i < _candidates.size(); ++i)
213
        {
214
215
            double pi = _candidates[i].second;
            if (pi < P)
216
            {
217
                for (size_t j = 0; j < remainder.size(); ++j)
218
                {
219
                    if (remainder[j].second >= P - pi)
220
                    {
221
222
223
                        _alias[i] = make_pair(remainder[j].first, pi * N);
                        remainder[j].second -= P - pi;
                        if (remainder[j].second <= 0)
224
                        {
225
226
                            swap(remainder[j], remainder.back());
                            remainder.pop_back();
227
                        }
228
                        break;
229
230
231
232
                    }
                }
            }
        }
233
        _rebuild = false;
234
    }
235
    
236

237
    ValueType operator()(rng_t& rng, bool remove = false)
238
    {
239
240
        if (_rebuild)
            BuildTable();
241

242
243
244
245
246
247
248
249
        tr1::variate_generator<rng_t&, tr1::uniform_real<> >
            sample(rng, tr1::uniform_real<>(0.0, 1.0));
        size_t i;
        do
        {
            double r = sample() * _candidates.size();
            i = floor(r);       // in [0, n-1]
            double x = r - i;   // in [0, 1)
250

251
252
253
254
            if (x > _alias[i].second)
                i = _alias[i].first;
        }
        while (_erased[i]);
255

256
257
258
259
260
261
        if (remove)
        {
            _erased[i] = true;
            _erased_prob += _candidates[i].second;
            if (_erased_prob >= 0.3)
                _rebuild = true;
262
        }
263
        return _candidates[i].first;
264
265
266
    }

private:
267
    vector<pair<ValueType, double> > _candidates;
268
269
    tr1::unordered_multimap<ValueType, size_t, hash<ValueType> >
        _candidates_set;
270
    vector<pair<size_t, double> > _alias;
271
    vector<uint8_t> _erased;
272
273
    bool _erased_prob;
    bool _rebuild;
274
275
276
277
278
};

} // namespace graph_tool

#endif // SAMPLER_HH