Commit 7be93ef4 authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

inference: fix issue with clabel in multilevel MCMC

This fixes issue #706
parent 98d213e0
......@@ -1489,6 +1489,8 @@ public:
{
auto& hb = _coupled_state->get_b();
hb[s] = hb[r];
auto& hpclabel = _coupled_state->get_pclabel();
hpclabel[s] = _pclabel[v];
}
return s;
}
......@@ -2263,6 +2265,11 @@ public:
return _b;
}
vprop_map_t<int32_t>::type::unchecked_t& get_bclabel()
{
return _bclabel;
}
vprop_map_t<int32_t>::type::unchecked_t& get_pclabel()
{
return _pclabel;
......
......@@ -832,22 +832,25 @@ struct Multilevel: public State
if (State::_has_b_min)
{
push_b(vs);
double S = 0;
if (rs.size() > B_min)
if (B_min == _B_min)
{
for (auto& v : vs)
push_b(vs);
double S = 0;
if (rs.size() > B_min)
{
auto r = State::get_group(v);
Group t = _b_min[v];
if (r == t)
continue;
S += State::virtual_move(v, r, t);
move_node(v, t);
for (auto& v : vs)
{
auto r = State::get_group(v);
Group t = _b_min[v];
if (r == t)
continue;
S += State::virtual_move(v, r, t);
move_node(v, t);
}
}
put_cache(B_min, S);
pop_b();
}
put_cache(B_min, S);
pop_b();
}
else if (B_min == 1)
{
......@@ -884,7 +887,7 @@ struct Multilevel: public State
rs.insert(s);
continue;
}
auto t = State::get_new_group(v, std::isinf(_beta), rng);
auto t = State::get_new_group(v, true, rng);
S += State::virtual_move(v, s, t);
move_node(v, t);
rs.insert(t);
......@@ -912,10 +915,11 @@ struct Multilevel: public State
B_max = B_max_init = std::min(rs.size(), B_max);
}
put_cache(rs.size(), S);
if (cache.find(rs.size()) == cache.end())
put_cache(rs.size(), S);
State::relax_update(false);
pop_b();
State::relax_update(false);
get_cache(rs.size(), rs);
}
else
......
......@@ -299,6 +299,9 @@ class MultilevelMCMCState(ABC):
def _get_entropy_args(self, kwargs):
pass
def _get_clabel(self):
return None
@mcmc_sweep_wrap
def multilevel_mcmc_sweep(self, niter=1, beta=1., c=.5, psingle=None,
pmultilevel=1, d=0.01, r=0.9, random_bisect=True,
......@@ -406,8 +409,11 @@ class MultilevelMCMCState(ABC):
global_moves = True
else:
global_moves = False
bclabel = self._get_bclabel()
if bclabel is not None:
B_min = max(len(numpy.unique(bclabel.fa)), B_min)
if b_min is None:
b_min = self.g.new_vp("int")
b_min = self.g.vertex_index.copy("int")
if b_max is None:
b_max = self.g.new_vp("int")
......
......@@ -224,6 +224,12 @@ class BlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
self.deg_corr = deg_corr
self.overlap = False
if clabel is None:
clabel = pclabel
if b is None:
b = clabel
if B is None and b is None:
B = 1
......@@ -315,8 +321,6 @@ class BlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
else:
self.clabel = self.g.new_vp("int")
self.clabel.fa = clabel
elif self.pclabel.fa.max() > 0:
self.clabel = self.pclabel
else:
self.clabel = self.g.new_vp("int")
......@@ -648,16 +652,19 @@ class BlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
rec_params.append("microcanonical")
rec_params = kwargs.pop("rec_params", rec_params)
if b is None:
b = self.get_bclabel()
state = BlockState(bg,
eweight=eweight,
vweight=vweight,
b=bg.vertex_index.copy("int") if b is None else b,
b=b,
deg_corr=deg_corr,
rec_types=rec_types,
recs=recs,
drec=drec,
rec_params=rec_params,
clabel=kwargs.pop("clabel", self.get_bclabel()),
clabel=kwargs.pop("clabel", self.bclabel),
pclabel=kwargs.pop("pclabel", self.get_bpclabel()),
dense_bg=self.dense_bg,
epsilon=kwargs.pop("epsilon",
......@@ -665,7 +672,7 @@ class BlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
**kwargs)
if copy_coupled and self._coupled_state is not None:
state._couple_state(state.get_block_state(b=state.get_bclabel(),
state._couple_state(state.get_block_state(b=self.bclabel,
copy_bg=False,
vweight="nonempty",
Lrecdx=state.Lrecdx),
......@@ -1325,6 +1332,8 @@ class BlockState(MCMCState, MultiflipMCMCState, MultilevelMCMCState,
return libinference.multiflip_mcmc_sweep_parallel(mcmc_states,
[s._state for s in states],
_get_rng())
def _get_bclabel(self):
return self.bclabel
def _multilevel_mcmc_sweep_dispatch(self, mcmc_state):
return libinference.multilevel_mcmc_sweep(mcmc_state, self._state,
......
......@@ -85,7 +85,7 @@ class LayeredBlockState(OverlapBlockState, BlockState):
"""
def __init__(self, g, ec, eweight=None, vweight=None, recs=[], rec_types=[],
rec_params=[], b=None, B=None, clabel=None, pclabel=False,
rec_params=[], b=None, B=None, clabel=None, pclabel=None,
layers=False, deg_corr=True, overlap=False, **kwargs):
kwargs = kwargs.copy()
......@@ -945,6 +945,8 @@ class LayeredBlockState(OverlapBlockState, BlockState):
return libinference.multiflip_mcmc_layered_overlap_sweep_parallel(mcmc_states,
[s._state for s in states],
_get_rng())
def _get_bclabel(self):
return self.agg_state._get_bclabel()
def _multilevel_mcmc_sweep_dispatch(self, mcmc_state):
if not self.overlap:
......
......@@ -240,17 +240,11 @@ def minimize_nested_blockmodel_dl(g, init_bs=None,
"""
L = int(numpy.ceil(numpy.log2(g.num_vertices())))
if init_bs is None:
bs = [numpy.zeros(1)] * (L + 1)
else:
bs = init_bs
state = state(g, bs=bs, **state_args)
state = state(g, bs=init_bs, **state_args)
args = dict(niter=1, psingle=0, beta=numpy.inf)
args.update(multilevel_mcmc_args)
l = 0
while l >= 0:
......
......@@ -67,13 +67,6 @@ class NestedBlockState(object):
hstate_args={}, hentropy_args={}, sampling=True, **kwargs):
self.g = g
if bs is None:
if base_type is OverlapBlockState or state_args.get("overlap", False):
b = zeros(2 * g.num_edges(), dtype="int")
else:
b = zeros(g.num_vertices(), dtype="int")
bs = [b] + [zeros(1, dtype="int")] * int(ceil(log2(len(b))))
self.base_type = base_type
if base_type is LayeredBlockState:
self.Lrecdx = []
......@@ -104,7 +97,17 @@ class NestedBlockState(object):
recs=True,
recs_dl=False,
beta_dl=1.)
if bs is None:
if base_type is OverlapBlockState:
N = 2 * g.num_edges()
else:
N = g.num_vertices()
L = int(numpy.ceil(numpy.log2(N)))
bs = [None] * (L + 1)
self.levels = [base_type(g, b=bs[0], **self.state_args)]
for i, b in enumerate(bs[1:]):
state = self.levels[-1]
args = self.hstate_args
......
......@@ -91,6 +91,12 @@ class OverlapBlockState(BlockState):
if node_index is not None and self.base_g is None:
raise ValueError("Must specify base graph if node_index is specified...")
if clabel is None:
clabel = pclabel
if b is None:
b = clabel
if B is None and b is None:
B = 1
......@@ -619,6 +625,9 @@ class OverlapBlockState(BlockState):
[s._state for s in states],
_get_rng())
def _get_bclabel(self):
return self.bclabel
def _multilevel_mcmc_sweep_dispatch(self, mcmc_state):
return libinference.overlap_multilevel_mcmc_sweep(mcmc_state, self._state,
_get_rng())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment