Commit d0fb50fc authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

minimize_nested_blockmodel_dl(): Fix bug with b_max and overlapping state

parent cb3e6d45
Pipeline #136 passed with stage
......@@ -471,7 +471,7 @@ def minimize_nested_blockmodel_dl(g, B_min=None, B_max=None, b_min=None,
if b_min is None:
b_min = min_state.b.fa
if Bs is None:
bs = [min_state.b.a, zeros(min_state.b.a.max() + 1, dtype="int")]
bs = [min_state.b.fa, zeros(min_state.b.fa.max() + 1, dtype="int")]
else:
bs = []
bstate = max_state
......@@ -484,6 +484,9 @@ def minimize_nested_blockmodel_dl(g, B_min=None, B_max=None, b_min=None,
bs.append(bstate.b.a)
bstate = bstate.get_block_state()
if layers:
state_args = overlay(state_args, overlap=overlap)
state = NestedBlockState(g, bs=bs,
base_type=type(min_state),
deg_corr=deg_corr,
......
......@@ -163,7 +163,7 @@ class NestedBlockState(object):
if l == 0:
raise ValueError("cannot delete level l=0")
b = self.project_partition(l, l - 1)
self.replace_level(l - 1, b)
self.replace_level(l - 1, b.fa)
del self.levels[l]
if _bm_test():
......@@ -296,13 +296,16 @@ class NestedBlockState(object):
state = self.levels[l]
if b_max is None:
b_max = state.g.vertex_index.copy("int").a
else:
b_max = b_max + (b_max.max() + 1) * clabel.fa
continuous_map(b_max)
max_state = state.copy(b=b_max, clabel=clabel)
if B_max is not None and max_state.B > B_max:
max_state = mcmc_multilevel(max_state, B_max,
**mcmc_multilevel_args)
if l < len(self.levels) - 1:
if B_min is None:
min_state = state.copy(b=clabel, clabel=clabel)
min_state = state.copy(b=clabel.fa, clabel=clabel.fa)
else:
min_state = mcmc_multilevel(max_state, B_min,
**mcmc_multilevel_args)
......@@ -310,17 +313,21 @@ class NestedBlockState(object):
assert min_state.B == self.levels[l+1].B, (min_state.B,
self.levels[l+1].N)
else:
min_state = state.copy(b=clabel, clabel=clabel)
min_state = state.copy(b=clabel.fa, clabel=clabel.fa)
if B_min is not None and min_state.B > B_min:
min_state = mcmc_multilevel(min_state, B_min,
**mcmc_multilevel_args)
if _bm_test():
assert min_state._check_clabel(), "invalid clabel %s" % str((l, self))
assert max_state._check_clabel(), "invalid clabel %s" % str((l, self))
# find new state
state = bisection_minimize([min_state, max_state], **bisection_args)
if _bm_test():
assert state.B >= min_state.B, (l, state.B, min_state.B, str(self))
assert state._check_clabel(), (l, str(self))
assert state._check_clabel(), "invalid clabel %s" % str((l, self))
return state
def draw(self, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment