// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2018 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef DYNAMIC_SAMPLER_HH
#define DYNAMIC_SAMPLER_HH

#include "random.hh"
#include <functional>
#include <boost/mpl/if.hpp>

namespace graph_tool
{
using namespace std;
using namespace boost;

// Weighted random sampler over a set of items that can grow and shrink.
//
// Items live in `_items`. Every item owns one leaf of an implicit binary tree
// stored in the flat arrays `_tree` / `_idx` (the children of node i are
// 2*i + 1 and 2*i + 2). A leaf holds the weight of its item, an internal node
// holds the sum of the weights below it, so `_tree[0]` is the total weight.
// `_idx[pos]` is the item index for a leaf and the "null" index (the largest
// size_t) for an internal or unused node.
//
// Removing an item does not shrink any array: the leaf weight is set to zero,
// the item slot is reset to `Value()` and flagged in `_valid`, and the leaf is
// remembered in `_free` so that a later insert() can reuse it.
template <class Value>
class DynamicSampler
{
public:
    typedef Value value_type;

    // Creates an empty sampler.
    DynamicSampler() : _back(0), _n_items(0) {}

    // Creates a sampler holding items[i] with weight probs[i].
    // `probs` is indexed with every index of `items`, so it must be at least
    // as long as `items`.
    DynamicSampler(const std::vector<Value>& items,
                   const std::vector<double>& probs)
        : _back(0), _n_items(0)
    {
        for (std::size_t i = 0; i < items.size(); ++i)
            insert(items[i], probs[i]);
    }

    // Index arithmetic of the implicit binary tree. The root is its own
    // parent.
    std::size_t get_left(std::size_t i)   const { return 2 * i + 1; }
    std::size_t get_right(std::size_t i)  const { return 2 * i + 2; }
    std::size_t get_parent(std::size_t i) const
    {
        return i > 0 ? (i - 1) / 2 : 0;
    }

    // Draws one item, descending from the root and choosing the left or right
    // subtree according to the weight sums stored in the tree.
    //
    // The sampler must contain at least one inserted item: `_tree[0]` is read
    // unconditionally.
    template <class RNG>
    const Value& sample(RNG& rng) const
    {
        std::uniform_real_distribution<> dist(0, _tree[0]);
        double u = dist(rng);
        double c = 0;  // total weight of the subtrees skipped to the left

        std::size_t pos = 0;
        while (_idx[pos] == null_idx())
        {
            const std::size_t l = get_left(pos);
            const double a = _tree[l];
            if (u < a + c)
            {
                pos = l;
            }
            else
            {
                pos = get_right(pos);
                c += a;
            }
        }
        return _items[_idx[pos]];
    }

    // Inserts `v` with weight `w` and returns the index of the item, usable
    // with operator[], is_valid() and remove().
    //
    // A leaf released by remove() is reused when one is available (together
    // with its item index); otherwise the tree grows by two nodes.
    std::size_t insert(const Value& v, double w)
    {
        std::size_t pos;
        if (_free.empty())
        {
            if (_back > 0)
            {
                // Turn the parent of the next free position into an internal
                // node: its item moves to the left child ...
                pos = get_parent(_back);
                const std::size_t l = get_left(pos);
                _idx[l] = _idx[pos];
                _ipos[_idx[l]] = l;
                _tree[l] = _tree[pos];
                _idx[pos] = null_idx();

                // ... and the new item goes to the right child.
                _back = get_right(pos);
            }

            pos = _back;
            check_size(pos);

            _idx[pos] = _items.size();
            _items.push_back(v);
            _valid.push_back(true);
            _ipos.push_back(pos);
            _tree[pos] = w;

            // Keep the node following the new leaf allocated; the next
            // growing insert writes to it as the left child above.
            ++_back;
            check_size(_back);
        }
        else
        {
            pos = _free.back();
            const std::size_t i = _idx[pos];
            _items[i] = v;
            _valid[i] = true;
            _tree[pos] = w;
            _free.pop_back();
        }

        insert_leaf_prob(pos);
        ++_n_items;
        return _idx[pos];
    }

    // Removes the item with index `i` (as returned by insert()).
    //
    // Calls with an index that is out of range or already removed are
    // ignored. Processing such a call would register the same leaf in `_free`
    // twice, so two later inserts would share one leaf, and would decrement
    // the item count a second time.
    void remove(std::size_t i)
    {
        if (!is_valid(i))
            return;

        const std::size_t pos = _ipos[i];
        remove_leaf_prob(pos);
        _free.push_back(pos);
        _items[i] = Value();
        _valid[i] = false;
        --_n_items;
    }

    // Removes everything. With `shrink` set, the memory held by the internal
    // arrays is additionally released via shrink_to_fit().
    void clear(bool shrink=false)
    {
        _items.clear();
        _ipos.clear();
        _tree.clear();
        _idx.clear();
        _free.clear();
        _valid.clear();
        if (shrink)
        {
            _items.shrink_to_fit();
            _ipos.shrink_to_fit();
            _tree.shrink_to_fit();
            _idx.shrink_to_fit();
            _free.shrink_to_fit();
            _valid.shrink_to_fit();
        }
        _back = 0;
        _n_items = 0;
    }

    // Rebuilds the sampler from the items that are still valid, dropping the
    // slots of removed items. Items are re-inserted in the order of their
    // tree positions, so item indices obtained before the call no longer
    // apply afterwards.
    void rebuild()
    {
        std::vector<Value> items;
        std::vector<double> probs;

        for (std::size_t i = 0; i < _tree.size(); ++i)
        {
            const std::size_t j = _idx[i];
            if (j == null_idx() || !_valid[j])
                continue;
            items.push_back(_items[j]);
            probs.push_back(_tree[i]);
        }

        clear(true);

        for (std::size_t i = 0; i < items.size(); ++i)
            insert(items[i], probs[i]);
    }

    // Item stored at index `i`; removed slots hold `Value()`.
    const Value& operator[](std::size_t i) const
    {
        return _items[i];
    }

    // True if `i` refers to an item that is currently in the sampler.
    bool is_valid(std::size_t i) const
    {
        return ((i < _items.size()) && _valid[i]);
    }

    // Underlying item storage, including the slots of removed items.
    const auto& items() const
    {
        return _items;
    }

    auto begin() const
    {
        return _items.begin();
    }

    auto end() const
    {
        return _items.end();
    }

    // Number of item slots, counting the slots of removed items.
    std::size_t size() const
    {
        return _items.size();
    }

    // True if no valid item is left.
    bool empty() const
    {
        return _n_items == 0;
    }

private:

    // Marker stored in `_idx` for nodes that are not leaves.
    static constexpr std::size_t null_idx()
    {
        return std::numeric_limits<std::size_t>::max();
    }

    // Grows `_idx` and `_tree` so that node `i` exists.
    void check_size(std::size_t i)
    {
        if (i >= _tree.size())
        {
            _idx.resize(i + 1, null_idx());
            _tree.resize(i + 1, 0);
        }
    }

    // Subtracts the weight of leaf `i` from all its ancestors and zeroes the
    // leaf itself.
    void remove_leaf_prob(std::size_t i)
    {
        const double w = _tree[i];
        for (std::size_t node = i; node > 0;)
        {
            node = get_parent(node);
            _tree[node] -= w;
        }
        _tree[i] = 0;
    }

    // Adds the weight of leaf `i` to all its ancestors.
    void insert_leaf_prob(std::size_t i)
    {
        const double w = _tree[i];
        for (std::size_t node = i; node > 0;)
        {
            node = get_parent(node);
            _tree[node] += w;
        }
    }

    std::vector<Value>       _items;
    std::vector<std::size_t> _ipos;  // position of the item in the tree

    std::vector<double>      _tree;  // tree nodes with weight sums
    std::vector<std::size_t> _idx;   // index in _items
    std::size_t _back;               // tree position following the last
                                     // appended leaf

    std::vector<std::size_t> _free;  // empty leafs
    std::vector<bool> _valid;        // non-removed items
    std::size_t _n_items;            // number of valid items
};



} // namespace graph_tool

#endif // DYNAMIC_SAMPLER_HH