Commit 1d7bbba4 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Fix bug in BlockState.get_edges_prob()

parent acfeb0ac
Pipeline #142 failed with stage
......@@ -218,26 +218,34 @@ void export_blockmodel_state()
([&](auto* s)
{
typedef typename std::remove_reference<decltype(*s)>::type state_t;
void (state_t::*remove_vertex)(size_t) =
&state_t::remove_vertex;
void (state_t::*add_vertex)(size_t, size_t) =
&state_t::add_vertex;
void (state_t::*remove_vertices)(python::object) =
&state_t::remove_vertices;
void (state_t::*add_vertices)(python::object, python::object) =
&state_t::add_vertices;
double (state_t::*virtual_move)(size_t, size_t, bool, bool, bool,
bool, bool) =
&state_t::virtual_move;
size_t (state_t::*sample_block)(size_t, double, vector<size_t>&,
rng_t&)
= &state_t::sample_block;
rng_t&) =
&state_t::sample_block;
double (state_t::*get_move_prob)(size_t, size_t, size_t, double,
bool)
= &state_t::get_move_prob;
void (state_t::*merge_vertices)(size_t, size_t)
= &state_t::merge_vertices;
void (state_t::*set_partition)(boost::any&)
= &state_t::set_partition;
bool) =
&state_t::get_move_prob;
void (state_t::*merge_vertices)(size_t, size_t) =
&state_t::merge_vertices;
void (state_t::*set_partition)(boost::any&) =
&state_t::set_partition;
class_<state_t> c(name_demangle(typeid(state_t).name()).c_str(),
no_init);
c.def("remove_vertex", &state_t::remove_vertex)
.def("add_vertex", &state_t::add_vertex)
c.def("remove_vertex", remove_vertex)
.def("add_vertex", add_vertex)
.def("remove_vertices", remove_vertices)
.def("add_vertices", add_vertices)
.def("move_vertex", &state_t::move_vertex)
.def("set_partition", set_partition)
.def("virtual_move", virtual_move)
......
......@@ -94,7 +94,8 @@ public:
}
// remove a vertex from its current block
void remove_vertex(size_t v)
template <class EFilt>
void remove_vertex(size_t v, EFilt&& efilt)
{
typedef typename graph_traits<g_t>::vertex_descriptor vertex_t;
......@@ -103,12 +104,15 @@ public:
int self_weight = 0;
for (auto e : out_edges_range(v, _g))
{
if (efilt(e))
continue;
vertex_t u = target(e, _g);
vertex_t s = _b[u];
auto& me = _emat.get_bedge(e);
size_t ew = _eweight[e];
auto ew = _eweight[e];
if (u == v && !is_directed::apply<g_t>::type::value)
{
self_weight += ew;
......@@ -141,6 +145,9 @@ public:
for (auto e : in_edges_range(v, _g))
{
if (efilt(e))
continue;
vertex_t u = source(e, _g);
if (u == v)
continue;
......@@ -148,7 +155,7 @@ public:
auto& me = _emat.get_bedge(e);
size_t ew = _eweight[e];
auto ew = _eweight[e];
_mrs[me] -= ew;
_mrp[s] -= ew;
......@@ -168,8 +175,66 @@ public:
_eweight, _degs);
}
void remove_vertex(size_t v)
{
remove_vertex(v, [](auto&){ return false; });
}
template <class Vlist>
void remove_vertices(Vlist& vs)
{
typedef typename graph_traits<g_t>::vertex_descriptor vertex_t;
gt_hash_set<vertex_t> vset(vs.begin(), vs.end());
typedef typename graph_traits<g_t>::edge_descriptor edges_t;
gt_hash_set<edges_t> eset;
for (auto v : vset)
{
for (auto e : all_edges_range(v, _g))
{
auto u = (source(e, _g) == v) ? target(e, _g) : source(e, _g);
if (vset.find(u) != vset.end())
eset.insert(e);
}
}
for (auto v : vset)
remove_vertex(v, [&](auto& e) { return eset.find(e) != eset.end(); });
for (auto& e : eset)
{
vertex_t v = source(e, _g);
vertex_t u = target(e, _g);
vertex_t r = _b[v];
vertex_t s = _b[u];
auto& me = _emat.get_bedge(e);
auto ew = _eweight[e];
_mrs[me] -= ew;
assert(_mrs[me] >= 0);
_mrp[r] -= ew;
_mrm[s] -= ew;
if (_mrs[me] == 0)
_emat.remove_me(r, s, me, _bg);
}
}
void remove_vertices(python::object ovs)
{
vector<size_t> vs;
for (int i = 0; i < python::len(ovs); ++i)
vs.push_back(python::extract<size_t>(ovs[i]));
remove_vertices(vs);
}
// add a vertex to block r
void add_vertex(size_t v, size_t r)
template <class Efilt>
void add_vertex(size_t v, size_t r, Efilt&& efilt)
{
typedef typename graph_traits<g_t>::vertex_descriptor vertex_t;
typedef typename graph_traits<bg_t>::edge_descriptor bedge_t;
......@@ -177,6 +242,8 @@ public:
int self_weight = 0;
for (auto e : out_edges_range(v, _g))
{
if (efilt(e))
continue;
vertex_t u = target(e, _g);
vertex_t s;
......@@ -225,6 +292,9 @@ public:
for (auto e : in_edges_range(v, _g))
{
if (efilt(e))
continue;
vertex_t u = source(e, _g);
if (u == v)
continue;
......@@ -263,6 +333,80 @@ public:
_eweight, _degs);
}
void add_vertex(size_t v, size_t r)
{
add_vertex(v, r, [](auto&) { return false; });
}
template <class Vlist, class Blist>
void add_vertices(Vlist& vs, Blist& rs)
{
typedef typename graph_traits<g_t>::vertex_descriptor vertex_t;
typedef typename graph_traits<bg_t>::edge_descriptor bedge_t;
gt_hash_map<vertex_t, size_t> vset;
for (size_t i = 0; i < vs.size(); ++i)
vset[vs[i]] = rs[i];
typedef typename graph_traits<g_t>::edge_descriptor edges_t;
gt_hash_set<edges_t> eset;
for (auto vr : vset)
{
auto v = vr.first;
for (auto e : all_edges_range(v, _g))
{
auto u = (source(e, _g) == v) ? target(e, _g) : source(e, _g);
if (vset.find(u) != vset.end())
eset.insert(e);
}
}
for (auto vr : vset)
add_vertex(vr.first, vr.second,
[&](auto& e){ return eset.find(e) != eset.end(); });
for (auto e : eset)
{
vertex_t v = source(e, _g);
vertex_t u = target(e, _g);
vertex_t r = vset[v];
vertex_t s = vset[u];
auto me = _emat.get_me(r, s);
if (me == bedge_t())
{
me = add_edge(r, s, _bg).first;
_emat.put_me(r, s, me);
_c_mrs[me] = 0;
}
_emat.get_bedge(e) = me;
assert(_emat.get_bedge(e) != bedge_t());
assert(me == _emat.get_me(r, s));
auto ew = _eweight[e];
_mrs[me] += ew;
_mrp[r] += ew;
_mrm[s] += ew;
}
}
void add_vertices(python::object ovs, python::object ors)
{
vector<size_t> vs;
vector<size_t> rs;
for (int i = 0; i < python::len(ovs); ++i)
{
vs.push_back(python::extract<size_t>(ovs[i]));
rs.push_back(python::extract<size_t>(ors[i]));
}
add_vertices(vs, rs);
}
// move a vertex from its current block to block nr
void move_vertex(size_t v, size_t nr)
{
......
......@@ -32,6 +32,7 @@ import random
from numpy import *
import numpy
import copy
import collections
from . util import *
......@@ -600,23 +601,33 @@ class BlockState(object):
def remove_vertex(self, v):
r"""Remove vertex ``v`` from its current group.
This optionally accepts a list of vertices to remove.
.. warning::
This will leave the state in an inconsistent state before the vertex
is returned to some other group, or if the same vertex is removed
twice.
"""
self._state.remove_vertex(int(v))
if isinstance(v, collections.Iterable):
self._state.remove_vertices(list(v))
else:
self._state.remove_vertex(int(v))
def add_vertex(self, v, r):
r"""Add vertex ``v`` to block ``r``.
This optionally accepts a list of vertices and blocks to add.
.. warning::
This can leave the state in an inconsistent state if a vertex is
added twice to the same group.
"""
self._state.add_vertex(int(v), r)
if isinstance(v, collections.Iterable):
self._state.add_vertices(list(v), list(r))
else:
self._state.add_vertex(int(v), r)
def merge_vertices(self, u, v):
r"""Merge vertex ``u`` into ``v``.
......@@ -665,39 +676,41 @@ class BlockState(object):
Si = self.entropy(**entropy_args)
for v in pos.keys():
self.remove_vertex(v)
self.remove_vertex(pos.keys())
try:
if missing:
new_es = []
for u, v in edge_list:
e = self.g.add_edge(u, v)
if self.is_edge_weighted:
self.eweight[e] = 1
new_es.append(e)
self.E += 1
else:
old_es = []
for e in edge_list:
u, v = e
if isinstance(e, tuple):
u, v = e
tmp = self.g.edge(u, v)
if tmp is None:
e = self.g.edge(u, v)
if e is None:
raise ValueError("edge not found: (%d, %d)" % (int(u),
int(v)))
self.g.remove_edge(tmp)
if self.is_edge_weighted:
self.eweight[e] -= 1
if self.eweight[e] == 0:
self.g.remove_edge(e)
else:
u, v = e
self.g.remove_edge(e)
old_es.append((u, v))
self.E -= 1
for v in pos.keys():
self.add_vertex(v, pos[v])
self.add_vertex(pos.keys(), pos.values())
Sf = self.entropy(**entropy_args)
for v in pos.keys():
self.remove_vertex(v)
self.remove_vertex(pos.keys())
finally:
if missing:
......@@ -706,10 +719,16 @@ class BlockState(object):
self.E -= 1
else:
for u, v in old_es:
self.g.add_edge(u, v)
if self.is_edge_weighted:
e = self.g.edge(u, v)
if e is None:
e = self.g.add_edge(u, v)
self.eweight[e] = 0
self.eweight[e] += 1
else:
self.g.add_edge(u, v)
self.E += 1
for v in pos.keys():
self.add_vertex(v, pos[v])
self.add_vertex(pos.keys(), pos.values())
if missing:
return Si - Sf
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment