Commit c688f9a3 authored by Tiago Peixoto's avatar Tiago Peixoto

blockmodel: Refactor neighbour and edge sampling

parent d953e5f0
......@@ -1755,9 +1755,7 @@ public:
EGroups<g_t, is_weighted_t> _egroups;
typedef typename std::conditional<is_weighted_t::value,
WeightedNeighbourSampler<g_t, DynamicSampler>,
UnweightedNeighbourSampler<g_t>>::type
typedef NeighbourSampler<g_t, is_weighted_t, std::true_type>
neighbour_sampler_t;
neighbour_sampler_t _neighbour_sampler;
......
......@@ -1329,7 +1329,11 @@ public:
_egroups.resize(num_vertices(bg));
for (auto e : edges_range(g))
{
_epos[e] = make_pair(numeric_limits<size_t>::max(),
numeric_limits<size_t>::max());
insert_edge(e, eweight[e], b, g);
}
}
void clear()
......@@ -1342,6 +1346,40 @@ public:
return _egroups.empty();
}
template <class Vprop>
bool check(Vprop b, Graph& g)
{
for (size_t r = 0; r < _egroups.size(); ++r)
{
auto& edges = _egroups[r];
for (size_t i = 0; i < edges.size(); ++i)
{
const auto& e = edges[i];
if (!is_valid(i, edges))
continue;
if (size_t(b[source(get<0>(e), g)]) != r &&
size_t(b[target(get<0>(e), g)]) != r)
{
assert(false);
return false;
}
}
}
return true;
}
template <class Edge>
bool is_valid(size_t i, DynamicSampler<Edge>& elist)
{
return elist.is_valid(i);
}
template <class Edge>
bool is_valid(size_t, vector<Edge>& elist)
{
return true;
}
template <class Edge, class Vprop>
void insert_edge(const Edge& e, size_t weight, Vprop& b, Graph& g)
{
......@@ -1361,6 +1399,7 @@ public:
{
if (pos < elist.size() && elist[pos] == e)
return;
assert(pos >= elist.size() || elist[pos] != e);
elist.push_back(e);
pos = elist.size() - 1;
}
......@@ -1369,8 +1408,9 @@ public:
void insert_edge(const Edge& e, DynamicSampler<Edge>& elist,
size_t weight, size_t& pos)
{
if (pos < elist.size() && elist[pos] == e)
if (pos < elist.size() && elist.is_valid(pos) && elist[pos] == e)
return;
assert(pos >= elist.size() || !elist.is_valid(pos) || elist[pos] != e);
pos = elist.insert(e, weight);
}
......
......@@ -28,35 +28,39 @@
namespace graph_tool
{
template <class Graph>
class UnweightedNeighbourSampler
template <class Graph, class Weighted, class Dynamic>
class NeighbourSampler
{
public:
typedef typename boost::graph_traits<Graph>::vertex_descriptor vertex_t;
template <class Eprop>
UnweightedNeighbourSampler(Graph& g, Eprop&, bool self_loops=false)
: _sampler_c(get(vertex_index_t(), g)),
_sampler(_sampler_c.get_unchecked()),
_sampler_pos_c(get(vertex_index_t(), g)),
_sampler_pos(_sampler_pos_c.get_unchecked())
NeighbourSampler(Graph& g, Eprop& eweight, bool self_loops=false)
: _sampler(get(vertex_index_t(), g), num_vertices(g)),
_sampler_pos(get(vertex_index_t(), g), num_vertices(g)),
_eindex(get(edge_index_t(), g))
{
for (auto v : vertices_range(g))
for (auto e : edges_range(g))
{
auto& sampler = _sampler_c[v];
auto& sampler_pos = _sampler_pos_c[v];
sampler.clear();
sampler_pos.clear();
for (auto e : all_edges_range(v, g))
auto u = source(e, g);
auto v = target(e, g);
if (!self_loops && u == v)
continue;
auto w = eweight[e];
if (w == 0)
continue;
if (u == v)
{
insert(v, u, w, e);
}
else
{
auto u = target(e, g);
if (is_directed::apply<Graph>::type::value && u == v)
u = source(e, g);
if (!self_loops && u == v)
continue;
sampler_pos[u].push_back(sampler.size());
sampler.push_back(u); // Self-loops will be added twice
insert(v, u, w, e);
insert(u, v, w, e);
}
}
}
......@@ -64,7 +68,9 @@ public:
template <class RNG>
vertex_t sample(vertex_t v, RNG& rng)
{
return uniform_sample(_sampler[v], rng);
auto& sampler = _sampler[v];
auto& item = sample_item(sampler, rng);
return item.first;
}
bool empty(vertex_t v)
......@@ -72,135 +78,92 @@ public:
return _sampler[v].empty();
}
template <class Weight>
void remove(vertex_t v, vertex_t u, Weight)
template <class Edge>
void remove(vertex_t v, vertex_t u, Edge&& e)
{
auto& sampler = _sampler[v];
auto& sampler_pos = _sampler_pos[v];
auto& pos_u = sampler_pos[u];
size_t i = pos_u.back();
sampler[i] = sampler.back();
for (auto& j : sampler_pos[sampler.back()])
{
if (j == sampler.size() - 1)
{
j = i;
break;
}
}
sampler.pop_back();
pos_u.pop_back();
if (pos_u.empty())
sampler_pos.erase(u);
auto k = std::make_pair(u, _eindex[e]);
remove_item(k, sampler, sampler_pos);
}
template <class Weight>
void insert(vertex_t v, vertex_t u, Weight)
template <class Weight, class Edge>
void insert(vertex_t v, vertex_t u, Weight w, Edge&& e)
{
auto& sampler = _sampler[v];
auto& sampler_pos = _sampler_pos[v];
sampler_pos[u].push_back(sampler.size());
sampler.push_back(u);
auto k = std::make_pair(u, _eindex[e]);
insert_item(k, w, sampler, sampler_pos);
}
private:
typedef typename vprop_map_t<std::vector<vertex_t>>::type sampler_t;
sampler_t _sampler_c;
typename sampler_t::unchecked_t _sampler;
typedef std::pair<vertex_t, size_t> item_t;
typedef gt_hash_map<item_t, size_t> pos_map_t;
typedef typename vprop_map_t<gt_hash_map<vertex_t, vector<size_t>>>::type sampler_pos_t;
sampler_pos_t _sampler_pos_c;
typename sampler_pos_t::unchecked_t _sampler_pos;
};
template <class Graph, template <class V> class Sampler>
class WeightedNeighbourSampler
{
public:
typedef typename boost::graph_traits<Graph>::vertex_descriptor vertex_t;
template <class Eprop>
WeightedNeighbourSampler(Graph& g, Eprop& eweight, bool self_loops=false)
: _sampler_c(get(vertex_index_t(), g)),
_sampler(_sampler_c.get_unchecked()),
_sampler_pos_c(get(vertex_index_t(), g)),
_sampler_pos(_sampler_pos_c.get_unchecked())
template <class RNG>
const item_t& sample_item(std::vector<item_t>& sampler, RNG& rng)
{
for (auto v : vertices_range(g))
{
auto& sampler = _sampler_c[v];
auto& sampler_pos = _sampler_pos_c[v];
sampler.clear();
sampler_pos.clear();
for (auto e : all_edges_range(v, g))
{
auto u = target(e, g);
if (is_directed::apply<Graph>::type::value && u == v)
u = source(e, g);
if (!self_loops && u == v)
continue;
auto w = eweight[e];
if (w == 0)
continue;
insert(v, u, w); // Self-loops will be added twice, and hence will
// be sampled with probability 2 * eweight[e]
}
}
return uniform_sample(sampler, rng);
}
template <class RNG>
vertex_t sample(vertex_t v, RNG& rng)
template <class Sampler, class RNG>
const item_t& sample_item(Sampler& sampler, RNG& rng)
{
return _sampler[v].sample(rng);
return sampler.sample(rng);
}
bool empty(vertex_t v)
void remove_item(item_t& u, std::vector<item_t>& sampler,
pos_map_t& sampler_pos)
{
return _sampler[v].empty();
auto& back = sampler.back();
size_t pos = sampler_pos[u];
sampler_pos[back] = pos;
sampler[pos] = back;
sampler.pop_back();
sampler_pos.erase(u);
}
template <class Weight>
void remove(vertex_t v, vertex_t u, Weight w)
template <class Sampler>
void remove_item(item_t& u, Sampler& sampler,
pos_map_t& sampler_pos)
{
auto& sampler = _sampler[v];
auto& sampler_pos = _sampler_pos[v];
size_t pos = sampler_pos[u];
sampler.remove(pos);
sampler_pos.erase(u);
}
auto& pos = sampler_pos[u];
sampler.remove(pos.first);
w -= pos.second;
if (w > 0)
insert(v, u, w);
else
sampler_pos.erase(u);
template <class Weight>
void insert_item(item_t& u, Weight, std::vector<item_t>& sampler,
pos_map_t& sampler_pos)
{
sampler_pos[u] = sampler.size();
sampler.push_back(u);
}
template <class Weight>
void insert(vertex_t v, vertex_t u, Weight w)
template <class Sampler, class Weight>
void insert_item(item_t& u, Weight w, Sampler& sampler,
pos_map_t& sampler_pos)
{
auto& sampler = _sampler[v];
auto& sampler_pos = _sampler_pos[v];
auto pos = sampler_pos.find(u);
if (pos != sampler_pos.end())
{
auto old_w = pos->second.second;
remove(v, u, old_w);
w += old_w;
}
sampler_pos[u] = std::make_pair(sampler.insert(u, w), w);
assert(sampler_pos.find(u) == sampler_pos.end());
sampler_pos[u] = sampler.insert(u, w);
}
private:
typedef typename vprop_map_t<Sampler<vertex_t>>::type sampler_t;
sampler_t _sampler_c;
typename sampler_t::unchecked_t _sampler;
typedef typename std::conditional<Weighted::value,
typename std::conditional<Dynamic::value,
DynamicSampler<item_t>,
Sampler<item_t>>::type,
vector<item_t>>::type
sampler_t;
typedef typename vprop_map_t<sampler_t>::type vsampler_t;
typename vsampler_t::unchecked_t _sampler;
typedef typename vprop_map_t<gt_hash_map<vertex_t, pair<size_t, double>>>::type sampler_pos_t;
sampler_pos_t _sampler_pos_c;
typedef typename vprop_map_t<pos_map_t>::type sampler_pos_t;
typename sampler_pos_t::unchecked_t _sampler_pos;
typename property_map<Graph, edge_index_t>::type _eindex;
};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment