// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2007-2012 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef SAMPLER_HH
#define SAMPLER_HH

#include "random.hh"
#include <cassert>
#include <iostream>
#include <utility>
#include <vector>

namespace graph_tool
{
using namespace std;
using namespace boost;

// utility class to sample uniformly from a collection of values
template <class ValueType>
class Sampler
{
public:
34
    Sampler() {}
35
36

    template <class Iterator>
37
    Sampler(Iterator iter, Iterator end)
38
39
    {
        for(; iter != end; ++iter)
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
            Insert(*iter);
    }

    void Insert(const ValueType& v)
    {
        _candidates.push_back(v);
        _candidates_set.insert(make_pair(v, _candidates.size() - 1));
    }

    bool HasValue(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end;
        tie(iter, end) = _candidates_set.equal_range(v);
        return (iter != end);
    }

    void Remove(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, back;
        iter = _candidates_set.find(v);

        if (iter == _candidates_set.end())
            return;

        back = _candidates_set.find(_candidates.back());

        size_t index = iter->second;
        swap(_candidates[index], _candidates.back());
        _candidates.pop_back();

        if (!_candidates.empty() && back != iter)
71
        {
72
73
            _candidates_set.erase(back);
            _candidates_set.insert(make_pair(_candidates[index], index));
74
        }
75
        _candidates_set.erase(iter);
76
77
    }

78
    bool Empty()
79
    {
80
81
82
83
84
85
86
87
88
89
90
91
92
93
        return _candidates.empty();
    }

    size_t Size()
    {
        return _candidates.size();
    }

    ValueType operator()(rng_t& rng, bool remove = false)
    {
        //assert(!_candidates.empty());
        tr1::uniform_int<> sample(0, _candidates.size() - 1);
        int i = sample(rng);
        if (remove)
94
        {
95
96
97
98
99
100
101
102
            swap(_candidates[i], _candidates.back());
            ValueType ret = _candidates.back();
            _candidates.pop_back();
            return ret;
        }
        else
        {
            return _candidates[i];
103
104
105
        }
    }

106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
private:
    vector<ValueType> _candidates;
    tr1::unordered_multimap<ValueType, size_t, hash<ValueType> >
        _candidates_set;
};


template <class ValueType>
class WeightedSampler
{
public:

    void Insert(const ValueType& v, double p)
    {
        _candidates.push_back(make_pair(v, p));
        _candidates_set.insert(make_pair(v, _candidates.size() - 1));
        _erased.push_back(false);
        _rebuild = true;
    }

126
127
128
129
130
131
132
133
134
135
136
    bool HasValue(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end;
        tie(iter, end) = _candidates_set.equal_range(v);
        return (iter != end);
    }

    void Remove(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end, temp;
        tie(iter, end) = _candidates_set.equal_range(v);
137
138
139
     
        if (iter == end)
            return;
140

141
        while(_erased[iter->second])
142
        {
143
144
145
146
            temp = iter++;
            _candidates_set.erase(temp);
            if (iter == end)
                return;
147
148
        }

149
150
151
152
153
        size_t index = iter->second;
        _erased[index] = true;
        _erased_prob += _candidates[index].second;
        if (_erased_prob >= 0.3)
            _rebuild = true;
154
155
156
157
158
159
160
161
162
163
164
165
    }

    bool Empty()
    {
        return _candidates.empty();
    }

    size_t Size()
    {
        return _candidates.size();
    }

166
    void BuildTable()
167
    {
168
169
170
        // remove possibly erased elements
        size_t i = 0;
        while (i < _candidates.size())
171
        {
172
            if (_erased[i])
173
174
            {
                swap(_candidates[i], _candidates.back());
175
                swap(_erased[i], _erased.back());
176
                _candidates.pop_back();
177
                _erased.pop_back();
178
179
180
            }
            else
            {
181
                ++i;
182
183
            }
        }
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
        _erased_prob = 0;

        vector<pair<size_t, double> > remainder;
        _alias.resize(_candidates.size());

        double P_sum = 0;
        for (size_t i = 0; i < _candidates.size(); ++i)
            P_sum += _candidates[i].second;

        
        size_t N = _candidates.size();
        double P = 1.0 / N;
        for (size_t i = 0; i < _candidates.size(); ++i)
        {
            _candidates[i].second /= P_sum;
            double pi = _candidates[i].second;
            if (pi > P)
                remainder.push_back(make_pair(i, pi - P));
            _alias[i] = make_pair(i, .1);
        }


        for (size_t i = 0; i < _candidates.size(); ++i)
207
        {
208
209
            double pi = _candidates[i].second;
            if (pi < P)
210
            {
211
                for (size_t j = 0; j < remainder.size(); ++j)
212
                {
213
                    if (remainder[j].second >= P - pi)
214
                    {
215
216
217
                        _alias[i] = make_pair(remainder[j].first, pi * N);
                        remainder[j].second -= P - pi;
                        if (remainder[j].second <= 0)
218
                        {
219
220
                            swap(remainder[j], remainder.back());
                            remainder.pop_back();
221
                        }
222
                        break;
223
224
225
226
                    }
                }
            }
        }
227
        _rebuild = false;
228
    }
229

230

231
    ValueType operator()(rng_t& rng, bool remove = false)
232
    {
233
234
        if (_rebuild)
            BuildTable();
235

236
237
238
239
240
241
242
243
        tr1::variate_generator<rng_t&, tr1::uniform_real<> >
            sample(rng, tr1::uniform_real<>(0.0, 1.0));
        size_t i;
        do
        {
            double r = sample() * _candidates.size();
            i = floor(r);       // in [0, n-1]
            double x = r - i;   // in [0, 1)
244

245
246
247
248
            if (x > _alias[i].second)
                i = _alias[i].first;
        }
        while (_erased[i]);
249

250
251
252
253
254
255
        if (remove)
        {
            _erased[i] = true;
            _erased_prob += _candidates[i].second;
            if (_erased_prob >= 0.3)
                _rebuild = true;
256
        }
257
        return _candidates[i].first;
258
259
260
    }

private:
261
    vector<pair<ValueType, double> > _candidates;
262
263
    tr1::unordered_multimap<ValueType, size_t, hash<ValueType> >
        _candidates_set;
264
    vector<pair<size_t, double> > _alias;
265
    vector<uint8_t> _erased;
266
267
    bool _erased_prob;
    bool _rebuild;
268
269
270
271
272
};

} // namespace graph_tool

#endif // SAMPLER_HH