Commit 81d28712 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

blockmodel: Improve weighted degrees

parent 68abf8f3
......@@ -44,7 +44,8 @@ GEN_DISPATCH(block_state, BlockState, BLOCK_STATE_params)
python::object make_block_state(boost::python::object ostate,
rng_t& rng);
degs_map_t get_block_degs(GraphInterface& gi, boost::any ab, boost::any aweight)
degs_map_t get_block_degs(GraphInterface& gi, boost::any ab, boost::any aweight,
size_t B)
{
degs_map_t degs;
vmap_t b = boost::any_cast<vmap_t>(ab);
......@@ -52,7 +53,7 @@ degs_map_t get_block_degs(GraphInterface& gi, boost::any ab, boost::any aweight)
[&](auto& g, auto& eweight)
{
std::vector<gt_hash_map<std::tuple<size_t, size_t>,
size_t>> hist;
size_t>> hist(B);
for (auto v : vertices_range(g))
{
size_t r = b[v];
......@@ -63,7 +64,7 @@ degs_map_t get_block_degs(GraphInterface& gi, boost::any ab, boost::any aweight)
hist[r][std::make_tuple(kin, kout)]++;
}
for (size_t r = 0; r < hist.size(); ++r)
for (size_t r = 0; r < B; ++r)
{
auto& deg = degs[r];
for (auto& kn : hist[r])
......@@ -77,7 +78,7 @@ degs_map_t get_block_degs(GraphInterface& gi, boost::any ab, boost::any aweight)
}
degs_map_t get_weighted_block_degs(GraphInterface& gi, degs_map_t& degs,
boost::any ab)
boost::any ab, size_t B)
{
degs_map_t ndegs;
vmap_t b = boost::any_cast<vmap_t>(ab);
......@@ -85,7 +86,7 @@ degs_map_t get_weighted_block_degs(GraphInterface& gi, degs_map_t& degs,
[&](auto& g)
{
std::vector<gt_hash_map<std::tuple<size_t, size_t>,
size_t>> hist;
size_t>> hist(B);
for (auto v : vertices_range(g))
{
size_t r = b[v];
......@@ -97,7 +98,7 @@ degs_map_t get_weighted_block_degs(GraphInterface& gi, degs_map_t& degs,
h[std::make_tuple(get<0>(k), get<1>(k))] += get<2>(k);
}
for (size_t r = 0; r < hist.size(); ++r)
for (size_t r = 0; r < B; ++r)
{
auto& deg = ndegs[r];
for (auto& kn : hist[r])
......
......@@ -231,6 +231,13 @@ class BlockState(object):
self.degs = kwargs.pop("degs", libinference.simple_degs_t())
if self.degs is None:
self.degs = libinference.simple_degs_t()
elif self.degs == "weighted":
idx_ = self.g.vertex_index.copy("int")
self.degs = libinference.get_block_degs(self.g._Graph__graph,
_prop("v", self.g, idx_),
self.eweight._get_any(),
self.g.num_vertices(True))
# ensure we have at most as many blocks as nodes
if B is not None and b is None:
......@@ -506,12 +513,14 @@ class BlockState(object):
if isinstance(self.degs, libinference.simple_degs_t):
degs = libinference.get_block_degs(self.g._Graph__graph,
_prop("v", self.g, self.b),
self.eweight._get_any())
self.eweight._get_any(),
self.get_B())
else:
degs = libinference.get_weighted_block_degs(self.g._Graph__graph,
self.degs,
_prop("v", self.g,
self.b))
self.b),
self.get_B())
else:
degs = libinference.simple_degs_t()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment