// graph-tool -- a general graph modification and manipulation thingy
//
// Copyright (C) 2006-2017 Tiago de Paula Peixoto <tiago@skewed.de>
//
// This program is free software; you can redistribute it and/or
// modify it under the terms of the GNU General Public License
// as published by the Free Software Foundation; either version 3
// of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

#ifndef GRAPH_NEIGHBOR_SAMPLER_HH
#define GRAPH_NEIGHBOR_SAMPLER_HH

#include "config.h"

#include <cassert>
#include <type_traits>
#include <utility>
#include <vector>

#include "graph_tool.hh"

// Sample neighbors efficiently
// =============================

namespace graph_tool
{

31
template <class Graph, class Weighted, class Dynamic>
32
class NeighborSampler
33
34
35
36
37
{
public:
    typedef typename boost::graph_traits<Graph>::vertex_descriptor vertex_t;

    template <class Eprop>
38
    NeighborSampler(Graph& g, Eprop& eweight, bool self_loops=false)
39
40
41
        : _sampler(get(vertex_index_t(), g), num_vertices(g)),
          _sampler_pos(get(vertex_index_t(), g), num_vertices(g)),
          _eindex(get(edge_index_t(), g))
42
43
44
45
46
47
48
49
    {
        init(g, eweight, self_loops,
             typename boost::mpl::and_<Weighted,
                                       typename boost::mpl::not_<Dynamic>::type>::type());
    }

    template <class Eprop>
    void init(Graph& g, Eprop& eweight, bool self_loops, boost::mpl::false_)
50
    {
51
        for (auto e : edges_range(g))
52
        {
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
            auto u = source(e, g);
            auto v = target(e, g);

            if (!self_loops && u == v)
                continue;

            auto w = eweight[e];

            if (w == 0)
                continue;

            if (u == v)
            {
                insert(v, u, w, e);
            }
            else
69
            {
70
71
                insert(v, u, w, e);
                insert(u, v, w, e);
72
73
74
75
            }
        }
    }

76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
    template <class Eprop>
    void init(Graph& g, Eprop& eweight, bool self_loops, boost::mpl::true_)
    {
        for (auto v : vertices_range(g))
        {
            std::vector<item_t> us;
            std::vector<double> probs;
            for (auto e : out_edges_range(v, g))
            {
                auto u = target(e, g);
                double w = eweight[e];
                if (w == 0)
                    continue;

                if (u == v)
                {
                    if (!self_loops)
                        continue;
94
                    if (!graph_tool::is_directed(g))
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
                        w /= 2;
                }
                us.emplace_back(u, 0);
                probs.push_back(w);
            }

            for (auto e : in_edges_range(v, g))
            {
                auto u = source(e, g);
                double w = eweight[e];
                if (w == 0 || u == v)
                    continue;
                us.emplace_back(u, 0);
                probs.push_back(w);
            }
            _sampler[v] = sampler_t(us, probs);
        }
    }

114
115
116
    template <class RNG>
    vertex_t sample(vertex_t v, RNG& rng)
    {
117
118
119
        auto& sampler = _sampler[v];
        auto& item = sample_item(sampler, rng);
        return item.first;
120
121
122
123
124
125
126
    }

    bool empty(vertex_t v)
    {
        return _sampler[v].empty();
    }

127
128
    template <class Edge>
    void remove(vertex_t v, vertex_t u, Edge&& e)
129
130
131
132
    {
        auto& sampler = _sampler[v];
        auto& sampler_pos = _sampler_pos[v];

133
134
        auto k = std::make_pair(u, _eindex[e]);
        remove_item(k, sampler, sampler_pos);
135
136
    }

137
138
    template <class Weight, class Edge>
    void insert(vertex_t v, vertex_t u, Weight w, Edge&& e)
139
140
141
    {
        auto& sampler = _sampler[v];
        auto& sampler_pos = _sampler_pos[v];
142
143
        auto k = std::make_pair(u, _eindex[e]);
        insert_item(k, w, sampler, sampler_pos);
144
145
146
    }

private:
147
148
    typedef std::pair<vertex_t, size_t> item_t;
    typedef gt_hash_map<item_t, size_t> pos_map_t;
149

150
151
    template <class RNG>
    const item_t& sample_item(std::vector<item_t>& sampler, RNG& rng)
152
    {
153
        return uniform_sample(sampler, rng);
154
155
    }

156
157
    template <class Sampler, class RNG>
    const item_t& sample_item(Sampler& sampler, RNG& rng)
158
    {
159
        return sampler.sample(rng);
160
161
    }

162
163
    void remove_item(item_t& u, std::vector<item_t>& sampler,
                     pos_map_t& sampler_pos)
164
    {
165
166
167
168
169
170
        auto& back = sampler.back();
        size_t pos = sampler_pos[u];
        sampler_pos[back] = pos;
        sampler[pos] = back;
        sampler.pop_back();
        sampler_pos.erase(u);
171
172
    }

173
174
175
    template <class Sampler>
    void remove_item(item_t& u, Sampler& sampler,
                     pos_map_t& sampler_pos)
176
    {
177
178
179
180
        size_t pos = sampler_pos[u];
        sampler.remove(pos);
        sampler_pos.erase(u);
    }
181
182


183
184
185
186
187
188
    template <class Weight>
    void insert_item(item_t& u, Weight, std::vector<item_t>& sampler,
                     pos_map_t& sampler_pos)
    {
        sampler_pos[u] = sampler.size();
        sampler.push_back(u);
189
190
    }

191
192
    template <class Weight>
    void insert_item(item_t& u, Weight w, DynamicSampler<item_t>& sampler,
193
                     pos_map_t& sampler_pos)
194
    {
195
196
        assert(sampler_pos.find(u) == sampler_pos.end());
        sampler_pos[u] = sampler.insert(u, w);
197
198
    }

199
200
201
    typedef typename std::conditional<Weighted::value,
                                      typename std::conditional<Dynamic::value,
                                                                DynamicSampler<item_t>,
202
203
                                                                Sampler<item_t,
                                                                        boost::mpl::false_>>::type,
204
205
206
207
208
                                      vector<item_t>>::type
        sampler_t;

    typedef typename vprop_map_t<sampler_t>::type vsampler_t;
    typename vsampler_t::unchecked_t _sampler;
209

210
    typedef typename vprop_map_t<pos_map_t>::type sampler_pos_t;
211
    typename sampler_pos_t::unchecked_t _sampler_pos;
212
213

    typename property_map<Graph, edge_index_t>::type _eindex;
214
215
216
217
};

}

218
#endif // GRAPH_NEIGHBOR_SAMPLER_HH