// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2007-2012 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef SAMPLER_HH
#define SAMPLER_HH

#include <cassert>
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

#if (GCC_VERSION >= 40400)
#   include <tr1/random>
#else
#   include <boost/tr1/random.hpp>
#endif

namespace graph_tool
{
// NOTE(review): these using-directives sit at namespace scope in a header,
// so every translation unit that includes it has all of std and boost made
// visible inside graph_tool.  Unqualified names in this header may rely on
// them -- confirm before removing.
using namespace std;
using namespace boost;

// Random number generator type accepted by the samplers declared in this
// header: the TR1 mt19937 engine.
typedef tr1::mt19937 rng_t;

// utility class to sample uniformly from a collection of values
template <class ValueType>
class Sampler
{
public:
40
    Sampler() {}
41 42

    template <class Iterator>
43
    Sampler(Iterator iter, Iterator end)
44 45
    {
        for(; iter != end; ++iter)
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
            Insert(*iter);
    }

    void Insert(const ValueType& v)
    {
        _candidates.push_back(v);
        _candidates_set.insert(make_pair(v, _candidates.size() - 1));
    }

    bool HasValue(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end;
        tie(iter, end) = _candidates_set.equal_range(v);
        return (iter != end);
    }

    void Remove(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, back;
        iter = _candidates_set.find(v);

        if (iter == _candidates_set.end())
            return;

        back = _candidates_set.find(_candidates.back());

        size_t index = iter->second;
        swap(_candidates[index], _candidates.back());
        _candidates.pop_back();

        if (!_candidates.empty() && back != iter)
77
        {
78 79
            _candidates_set.erase(back);
            _candidates_set.insert(make_pair(_candidates[index], index));
80
        }
81
        _candidates_set.erase(iter);
82 83
    }

84
    bool Empty()
85
    {
86 87 88 89 90 91 92 93 94 95 96 97 98 99
        return _candidates.empty();
    }

    size_t Size()
    {
        return _candidates.size();
    }

    ValueType operator()(rng_t& rng, bool remove = false)
    {
        //assert(!_candidates.empty());
        tr1::uniform_int<> sample(0, _candidates.size() - 1);
        int i = sample(rng);
        if (remove)
100
        {
101 102 103 104 105 106 107 108
            swap(_candidates[i], _candidates.back());
            ValueType ret = _candidates.back();
            _candidates.pop_back();
            return ret;
        }
        else
        {
            return _candidates[i];
109 110 111
        }
    }

112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
private:
    vector<ValueType> _candidates;
    tr1::unordered_multimap<ValueType, size_t, hash<ValueType> >
        _candidates_set;
};


template <class ValueType>
class WeightedSampler
{
public:

    void Insert(const ValueType& v, double p)
    {
        _candidates.push_back(make_pair(v, p));
        _candidates_set.insert(make_pair(v, _candidates.size() - 1));
        _erased.push_back(false);
        _rebuild = true;
    }

132 133 134 135 136 137 138 139 140 141 142
    bool HasValue(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end;
        tie(iter, end) = _candidates_set.equal_range(v);
        return (iter != end);
    }

    void Remove(const ValueType& v)
    {
        typeof(_candidates_set.begin()) iter, end, temp;
        tie(iter, end) = _candidates_set.equal_range(v);
143 144 145
     
        if (iter == end)
            return;
146

147
        while(_erased[iter->second])
148
        {
149 150 151 152
            temp = iter++;
            _candidates_set.erase(temp);
            if (iter == end)
                return;
153 154
        }

155 156 157 158 159
        size_t index = iter->second;
        _erased[index] = true;
        _erased_prob += _candidates[index].second;
        if (_erased_prob >= 0.3)
            _rebuild = true;
160 161 162 163 164 165 166 167 168 169 170 171
    }

    bool Empty()
    {
        return _candidates.empty();
    }

    size_t Size()
    {
        return _candidates.size();
    }

172
    void BuildTable()
173
    {
174 175 176
        // remove possibly erased elements
        size_t i = 0;
        while (i < _candidates.size())
177
        {
178
            if (_erased[i])
179 180
            {
                swap(_candidates[i], _candidates.back());
181
                swap(_erased[i], _erased.back());
182
                _candidates.pop_back();
183
                _erased.pop_back();
184 185 186
            }
            else
            {
187
                ++i;
188 189
            }
        }
190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212
        _erased_prob = 0;

        vector<pair<size_t, double> > remainder;
        _alias.resize(_candidates.size());

        double P_sum = 0;
        for (size_t i = 0; i < _candidates.size(); ++i)
            P_sum += _candidates[i].second;

        
        size_t N = _candidates.size();
        double P = 1.0 / N;
        for (size_t i = 0; i < _candidates.size(); ++i)
        {
            _candidates[i].second /= P_sum;
            double pi = _candidates[i].second;
            if (pi > P)
                remainder.push_back(make_pair(i, pi - P));
            _alias[i] = make_pair(i, .1);
        }


        for (size_t i = 0; i < _candidates.size(); ++i)
213
        {
214 215
            double pi = _candidates[i].second;
            if (pi < P)
216
            {
217
                for (size_t j = 0; j < remainder.size(); ++j)
218
                {
219
                    if (remainder[j].second >= P - pi)
220
                    {
221 222 223
                        _alias[i] = make_pair(remainder[j].first, pi * N);
                        remainder[j].second -= P - pi;
                        if (remainder[j].second <= 0)
224
                        {
225 226
                            swap(remainder[j], remainder.back());
                            remainder.pop_back();
227
                        }
228
                        break;
229 230 231 232
                    }
                }
            }
        }
233
        _rebuild = false;
234
    }
235
    
236

237
    ValueType operator()(rng_t& rng, bool remove = false)
238
    {
239 240
        if (_rebuild)
            BuildTable();
241

242 243 244 245 246 247 248 249
        tr1::variate_generator<rng_t&, tr1::uniform_real<> >
            sample(rng, tr1::uniform_real<>(0.0, 1.0));
        size_t i;
        do
        {
            double r = sample() * _candidates.size();
            i = floor(r);       // in [0, n-1]
            double x = r - i;   // in [0, 1)
250

251 252 253 254
            if (x > _alias[i].second)
                i = _alias[i].first;
        }
        while (_erased[i]);
255

256 257 258 259 260 261
        if (remove)
        {
            _erased[i] = true;
            _erased_prob += _candidates[i].second;
            if (_erased_prob >= 0.3)
                _rebuild = true;
262
        }
263
        return _candidates[i].first;
264 265 266
    }

private:
267
    vector<pair<ValueType, double> > _candidates;
268 269
    tr1::unordered_multimap<ValueType, size_t, hash<ValueType> >
        _candidates_set;
270
    vector<pair<size_t, double> > _alias;
271
    vector<uint8_t> _erased;
272 273
    bool _erased_prob;
    bool _rebuild;
274 275 276 277 278
};

} // namespace graph_tool

#endif // SAMPLER_HH