Commit 90cda9bd authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Fix bug with CovariateBlockState.merge_layers()

parent a9b02d3e
......@@ -654,6 +654,16 @@ void do_split_graph(GraphInterface& gi, boost::any& aec, boost::any& ab,
std::ref(uvmap)))();
}
bool bmap_has(const bmap_t& bmap, size_t c, size_t r)
{
if (c > bmap.size())
throw GraphException("invalid covariate value:" + lexical_cast<string>(c));
auto iter = bmap[c].find(r);
if (iter == bmap[c].end())
return false;
return true;
}
size_t bmap_get(const bmap_t& bmap, size_t c, size_t r)
{
if (c > bmap.size())
......@@ -688,6 +698,7 @@ bmap_t bmap_copy(const bmap_t& bmap)
void export_blockmodel_covariate()
{
boost::python::class_<bmap_t>("bmap_t")
.def("has", bmap_has)
.def("get", bmap_get)
.def("set", bmap_set)
.def("del_c", bmap_del_c)
......
......@@ -240,7 +240,7 @@ class CovariateBlockState(BlockState):
if openmp_enabled():
nt = openmp_get_num_threads()
B = u.num_vertices() + 2 * nt
B = max(B, u.vp["b"].a.max() + 1 + 2 * nt)
#B = max(B, u.vp["b"].a.max() + 1 + 2 * nt)
if not self.overlap:
state = BlockState(u, b=u.vp["b"],
B=B,
......@@ -335,15 +335,13 @@ class CovariateBlockState(BlockState):
src_rbmap = {}
r_max = 0
for r in range(self.B):
try:
if self.bmap.has(l_tgt + 1, r):
tgt_bmap[r] = self.bmap.get(l_tgt + 1, r)
r_max = max(r_max, tgt_bmap[r])
except RuntimeError:
pass
try:
if self.bmap.has(l_src + 1, r):
src_rbmap[self.bmap.get(l_src + 1, r)] = r
except RuntimeError:
pass
r_missing = list(set(range(r_max)) - set(tgt_bmap.values()))
r_max += 1
if self.overlap:
......@@ -356,10 +354,15 @@ class CovariateBlockState(BlockState):
if r in tgt_bmap:
nb[i] = tgt_bmap[r]
else:
self.bmap.set(l_tgt + 1, r, r_max)
nb[i] = r_max
tgt_bmap[r] = r_max
r_max += 1
if len(r_missing) > 0:
rr = r_missing[0]
del r_missing[0]
else:
rr = r_max
r_max += 1
self.bmap.set(l_tgt + 1, r, rr)
nb[i] = rr
tgt_bmap[r] = rr
b[e] = nb
b_src = b
b_tgt = u_tgt_base.ep["b"]
......@@ -374,10 +377,15 @@ class CovariateBlockState(BlockState):
if r in tgt_bmap:
b[v] = tgt_bmap[r]
else:
self.bmap.set(l_tgt + 1, r, r_max)
b[v] = r_max
tgt_bmap[r] = r_max
r_max += 1
if len(r_missing) > 0:
rr = r_missing[0]
del r_missing[0]
else:
rr = r_max
r_max += 1
self.bmap.set(l_tgt + 1, r, rr)
b[v] = rr
tgt_bmap[r] = rr
b_src = b
b_tgt = u_tgt_base.vp["b"]
......@@ -387,10 +395,12 @@ class CovariateBlockState(BlockState):
(u_tgt_base.ep["weight"], u_src_base.ep["weight"])]
if not self.overlap:
props.append((u_tgt_base.vp["brmap"], u_src_base.vp["brmap"]))
props.append((u_tgt_base.vp["brmap"],
u_src_base.vp["brmap"]))
u, props = graph_union(u_tgt_base, u_src_base,
intersection=intersection, props=props,
intersection=intersection,
props=props,
include=False)
if self.overlap:
......@@ -427,7 +437,7 @@ class CovariateBlockState(BlockState):
self.base_ec.a[self.base_ec.a > l_src] -= 1
self.C -= 1
old_bmap = self.bmap.copy()
self.bmap.del_c(l_src)
self.bmap.del_c(l_src + 1)
self.__bg = None
yield
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment