Commit a014890f authored by Tiago Peixoto's avatar Tiago Peixoto

Fix issue with nested SBM mcmc

This solves a problem when NestedBlockState.mcmc_sweep() is invoked with
nonstandard values of the parameter 'c'.
parent 3a609c11
Pipeline #201 failed with stage
in 1150 minutes and 43 seconds
......@@ -164,6 +164,8 @@ void export_blockmodel_state()
&state_t::virtual_move;
size_t (state_t::*sample_block)(size_t, double, rng_t&) =
&state_t::sample_block;
size_t (state_t::*random_neighbour)(size_t, rng_t&) =
&state_t::random_neighbour;
double (state_t::*get_move_prob)(size_t, size_t, size_t, double,
bool) =
&state_t::get_move_prob;
......@@ -184,6 +186,7 @@ void export_blockmodel_state()
.def("virtual_move", virtual_move)
.def("merge_vertices", merge_vertices)
.def("sample_block", sample_block)
.def("sample_neighbour", random_neighbour)
.def("entropy", &state_t::entropy)
.def("get_partition_dl", &state_t::get_partition_dl)
.def("get_deg_dl", &state_t::get_deg_dl)
......@@ -200,6 +203,8 @@ void export_blockmodel_state()
&state_t::decouple_state)
.def("clear_egroups",
&state_t::clear_egroups)
.def("rebuild_neighbour_sampler",
&state_t::rebuild_neighbour_sampler)
.def("sync_emat",
&state_t::sync_emat);
});
......
......@@ -115,8 +115,7 @@ public:
_m_entries(num_vertices(_bg)),
_coupled_state(nullptr)
{
init_neighbour_sampler(_g, _eweight, _neighbour_sampler);
rebuild_neighbour_sampler();
_empty_blocks.clear();
_candidate_blocks.clear();
for (auto r : vertices_range(_bg))
......@@ -1296,9 +1295,7 @@ public:
return sample_block<rng_t>(v, c, rng);
}
template <class RNG>
size_t random_neighbour(size_t v, RNG& rng)
size_t random_neighbour(size_t v, rng_t& rng)
{
if (_neighbour_sampler[v].size() == 0)
return v;
......@@ -1375,10 +1372,18 @@ public:
};
for (auto e : out_edges_range(v, _g))
{
if (target(e, _g) == v)
continue;
sum_prob(e, target(e, _g));
}
for (auto e : in_edges_range(v, _g))
{
if (source(e, _g) == v)
continue;
sum_prob(e, source(e, _g));
}
if (w > 0)
return p / w;
......@@ -1685,6 +1690,11 @@ public:
_egroups.clear();
}
void rebuild_neighbour_sampler()
{
init_neighbour_sampler(_g, _eweight, _neighbour_sampler);
}
void sync_emat()
{
_emat.sync(_g, _b, _bg);
......
......@@ -1508,7 +1508,7 @@ private:
template <class Vertex, class Graph, class Eprop, class SMap>
void build_neighbour_sampler(Vertex v, SMap& sampler, Eprop& eweight, Graph& g,
bool self_loops=true)
bool self_loops=false)
{
vector<Vertex> neighbours;
vector<double> probs;
......@@ -1539,7 +1539,7 @@ void build_neighbour_sampler(Vertex v, SMap& sampler, Eprop& eweight, Graph& g,
template <class Vertex, class Graph, class Eprop>
void build_neighbour_sampler(Vertex v, vector<size_t>& sampler, Eprop&, Graph& g,
bool self_loops=true)
bool self_loops=false)
{
sampler.clear();
for (auto e : all_edges_range(v, g))
......
......@@ -468,8 +468,6 @@ class NestedBlockState(object):
nmoves = 0
c = kwargs.get("c", None)
if c is not None and not isinstance(c, collections.Iterable):
c = [c] * len(self.levels)
for l in range(len(self.levels)):
if check_verbose(verbose):
......@@ -493,6 +491,7 @@ class NestedBlockState(object):
if l > 0:
self.levels[l]._state.sync_emat()
self.levels[l]._state.rebuild_neighbour_sampler()
if l < len(self.levels) - 1:
self.levels[l + 1]._state.sync_emat()
......@@ -510,22 +509,21 @@ class NestedBlockState(object):
def mcmc_sweep(self, **kwargs):
r"""Perform ``niter`` sweeps of a Metropolis-Hastings acceptance-rejection
sampling MCMC to sample hierarchical network partitions.
MCMC to sample hierarchical network partitions.
The arguments accepted are the same as in
:method:`graph_tool.inference.BlockState.mcmc_sweep`.
"""
c = kwargs.get("c", None)
if c is None:
c = [1] + [numpy.inf] * (len(self.levels) - 1)
kwargs = kwargs.copy()
kwargs["c"] = c
return self._h_sweep(lambda s, **a: s.mcmc_sweep(**a), **kwargs)
c = extract_arg(kwargs, "c", 1)
if not isinstance(c, collections.Iterable):
c = [c] + [c * 2 ** l for l in range(1, len(self.levels))]
return self._h_sweep(lambda s, **a: s.mcmc_sweep(**a), c=c, **kwargs)
def gibbs_sweep(self, **kwargs):
r"""Perform ``niter`` sweeps of a rejection-free Gibbs sampling MCMC
to sample network partitions.
r"""Perform ``niter`` sweeps of a rejection-free Gibbs MCMC to sample network
partitions.
The arguments accepted are the same as in
:method:`graph_tool.inference.BlockState.gibbs_sweep`.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment