Commit ecee9d6f authored by Tiago Peixoto's avatar Tiago Peixoto
Browse files

Fix problem with B_max in minimize_nested_blockmodel_dl()

This also removes (Nested)MinimizeState and checkpoints, in order to
simplify the code.
parent a9bc162f
Pipeline #104 failed with stage
......@@ -7,7 +7,6 @@
.. autoclass:: OverlapBlockState
.. autoclass:: CovariateBlockState
.. autofunction:: mcmc_sweep
.. autoclass:: MinimizeState
.. autofunction:: multilevel_minimize
.. autofunction:: collect_edge_marginals
.. autofunction:: collect_vertex_marginals
......@@ -19,7 +18,6 @@
.. autofunction:: condensation_graph
.. autofunction:: minimize_nested_blockmodel_dl
.. autoclass:: NestedBlockState
.. autoclass:: NestedMinimizeState
.. autofunction:: init_nested_state
.. autofunction:: nested_mcmc_sweep
.. autofunction:: nested_tree_sweep
......
......@@ -43,7 +43,6 @@ Summary
OverlapBlockState
CovariateBlockState
mcmc_sweep
MinimizeState
multilevel_minimize
collect_vertex_marginals
collect_edge_marginals
......@@ -66,7 +65,6 @@ Summary
minimize_nested_blockmodel_dl
NestedBlockState
NestedMinimizeState
init_nested_state
nested_mcmc_sweep
nested_tree_sweep
......@@ -105,7 +103,6 @@ __all__ = ["minimize_blockmodel_dl",
"OverlapBlockState",
"CovariateBlockState",
"mcmc_sweep",
"MinimizeState",
"multilevel_minimize",
"collect_edge_marginals",
"collect_vertex_marginals",
......@@ -117,7 +114,6 @@ __all__ = ["minimize_blockmodel_dl",
"condensation_graph",
"minimize_nested_blockmodel_dl",
"NestedBlockState",
"NestedMinimizeState",
"init_nested_state",
"nested_mcmc_sweep",
"nested_tree_sweep",
......@@ -128,14 +124,13 @@ __all__ = ["minimize_blockmodel_dl",
from . blockmodel import minimize_blockmodel_dl, BlockState, mcmc_sweep, \
multilevel_minimize, model_entropy, get_max_B, get_akc, condensation_graph, \
collect_edge_marginals, collect_vertex_marginals, bethe_entropy, mf_entropy, \
MinimizeState
collect_edge_marginals, collect_vertex_marginals, bethe_entropy, mf_entropy
from . overlap_blockmodel import OverlapBlockState, get_block_edge_gradient
from . covariate_blockmodel import CovariateBlockState
from . nested_blockmodel import NestedBlockState, NestedMinimizeState, \
from . nested_blockmodel import NestedBlockState, \
init_nested_state, nested_mcmc_sweep, nested_tree_sweep, \
minimize_nested_blockmodel_dl, get_hierarchy_tree
......
......@@ -1380,21 +1380,6 @@ def greedy_shrink(state, B, **kwargs):
return state
class MinimizeState(object):
r"""This object stores information regarding the current entropy minimization
state, so that the algorithms can resume previously started runs.
This object can be saved to disk via the :mod:`pickle` interface."""
def __init__(self):
self.b_cache = {}
self.checkpoint_state = defaultdict(dict)
self.init = True
def clear(self):
r"""Clear state."""
self.b_cache.clear()
self.checkpoint_state.clear()
def unilevel_minimize(state, nsweeps=10, adaptive_sweeps=True, epsilon=0,
anneal=(1., 1.), greedy=True, c=0., dl=False, dense=False,
multigraph=True, sequential=True, parallel=False,
......@@ -1545,8 +1530,8 @@ def unilevel_minimize(state, nsweeps=10, adaptive_sweeps=True, epsilon=0,
def multilevel_minimize(state, B, nsweeps=10, adaptive_sweeps=True, epsilon=0,
anneal=(1., 1.), r=2., nmerge_sweeps=10, greedy=True,
c=0., dl=False, dense=False, multigraph=True,
sequential=True, parallel=False, checkpoint=None,
minimize_state=None, verbose=False, **kwargs):
sequential=True, parallel=False, verbose=False,
**kwargs):
r"""Performs an agglomerative heuristic, which progressively merges blocks together (while allowing individual node moves) to achieve a good partition in ``B`` blocks.
Parameters
......@@ -1604,33 +1589,6 @@ def multilevel_minimize(state, B, nsweeps=10, adaptive_sweeps=True, epsilon=0,
vertices: ``list of ints`` (optional, default: ``None``)
A list of vertices which will be attempted to be moved. If ``None``, all
vertices will be attempted.
checkpoint : function (optional, default: ``None``)
If provided, this function will be called after each call to
:func:`mcmc_sweep`. This can be used to store the current state, so it
can be continued later. The function must have the following signature:
.. code-block:: python
def checkpoint(state, S, delta, nmoves, minimize_state):
...
where `state` is either a :class:`~graph_tool.community.BlockState`
instance or ``None``, `S` is the current entropy value, `delta` is
the entropy difference in the last MCMC sweep, and `nmoves` is the
number of accepted block membership moves. The ``minimize_state``
argument is a :class:`MinimizeState` instance which specifies the current
state of the algorithm, which can be stored via :mod:`pickle`, and
supplied via the ``minimize_state`` option below to continue from an
interrupted run.
This function will also be called when the MCMC has finished for the
current value of :math:`B`, in which case ``state == None``, and the
remaining parameters will be zero, except the last.
minimize_state : :class:`MinimizeState` (optional, default: ``None``)
If provided, this will specify an exact point of execution from which
the algorithm will continue. The expected object is a :class:`MinimizeState`
instance which will be passed to the callback of the ``checkpoint``
option above, and can be stored by :mod:`pickle`.
verbose : ``bool`` (optional, default: ``False``)
If ``True``, verbose information is displayed.
......@@ -1684,10 +1642,7 @@ def multilevel_minimize(state, B, nsweeps=10, adaptive_sweeps=True, epsilon=0,
:doi:`10.1103/PhysRevE.89.012804`, :arxiv:`1310.4378`.
"""
if minimize_state is None:
minimize_state = MinimizeState()
b_cache = minimize_state.b_cache
checkpoint_state = minimize_state.checkpoint_state
b_cache = kwargs.get("b_cache", {})
nkwargs = dict(nsweeps=nsweeps, epsilon=epsilon, c=c, dl=dl, dense=dense,
multigraph=multigraph, nmerge_sweeps=nmerge_sweeps,
......@@ -1749,7 +1704,7 @@ def multilevel_minimize(state, B, nsweeps=10, adaptive_sweeps=True, epsilon=0,
if verbose:
print("Minimizing for:", state.B)
dS, nmoves = unilevel_minimize(state, checkpoint=checkpoint, verbose=verbose, **kwargs)
dS, nmoves = unilevel_minimize(state, verbose=verbose, **kwargs)
if _bm_test():
assert state._BlockState__check_clabel(), "clabel invalidated after unilevel minimize!"
......@@ -1799,13 +1754,12 @@ def get_b_dl(state, dense, multigraph, nested_dl, complete=False,
return dl
def get_state_dl(B, minimize_state, checkpoint, sparse_heuristic, **kwargs):
bs = minimize_state.b_cache
checkpoint_state = minimize_state.checkpoint_state
def get_state_dl(B, b_cache, sparse_heuristic, **kwargs):
bs = b_cache
previous = None
verbose = kwargs.get("verbose", False)
if B in bs and checkpoint_state[B].get("done", False):
if B in bs:
# A previous finished result is available. Use that and keep going.
if verbose:
print("(using previous finished result for B=%d)" % B)
......@@ -1863,19 +1817,20 @@ def get_state_dl(B, minimize_state, checkpoint, sparse_heuristic, **kwargs):
# perform the actual minimization
args = kwargs.copy()
args["minimize_state"] = minimize_state
args["b_cache"] = bs
if sparse_heuristic:
args["dense"] = False
args["multigraph"] = False
#args["verbose"] = False
state = multilevel_minimize(state, B, checkpoint=checkpoint, **args)
state = multilevel_minimize(state, B, **args)
if _bm_test():
assert state._BlockState__check_clabel(), "clabel invalidated after minimize"
assert state.B == B
dl = get_b_dl(state, kwargs.get("dense", False),
dl = get_b_dl(state,
kwargs.get("dense", False),
kwargs.get("multigraph", False),
kwargs.get("nested_dl", False),
kwargs.get("complete", False),
......@@ -1907,8 +1862,6 @@ def get_state_dl(B, minimize_state, checkpoint, sparse_heuristic, **kwargs):
kwargs.get("dl_ent", False))
assert abs(dl - tdl) < 1e-8, "inconsistent DL values! (%g, %g)" % (dl, tdl)
checkpoint_state[B]["done"] = True
if _bm_test():
assert not isinf(dl)
return dl
......@@ -1941,8 +1894,7 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
adaptive_sweeps=True, epsilon=1e-3, anneal=(1., 1.),
greedy_cooling=True, sequential=True, parallel=False,
r=2, nmerge_sweeps=10, max_B=None, min_B=None,
mid_B=None, checkpoint=None, minimize_state=None,
random_bisection=False, exhaustive=False,
mid_B=None, random_bisection=False, exhaustive=False,
init_states=None, max_BE=None, verbose=False,
**kwargs):
r"""Find the block partition of an unspecified size which minimizes the description
......@@ -2036,34 +1988,6 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
mid_B : ``int`` (optional, default: ``None``)
Middle of the range which brackets the minimum. If not supplied, will be
automatically determined.
checkpoint : function (optional, default: ``None``)
If provided, this function will be called after each call to
:func:`mcmc_sweep`. This can be used to store the current state, so it
can be continued later. The function must have the following signature:
.. code-block:: python
def checkpoint(state, L, delta, nmoves, minimize_state):
...
where `state` is either a :class:`~graph_tool.community.BlockState`
instance or ``None``, `L` is the current description length, `delta` is
the entropy difference in the last MCMC sweep, and `nmoves` is the
number of accepted block membership moves. The ``minimize_state``
argument is a :class:`~graph_tool.community.MinimizeState` instance
which specifies the current state of the algorithm, which can be stored
via :mod:`pickle`, and supplied via the ``minimize_state`` option below
to continue from an interrupted run.
This function will also be called when the MCMC has finished for the
current value of :math:`B`, in which case ``state == None``, and the
remaining parameters will be zero, except the last.
minimize_state : :class:`~graph_tool.community.MinimizeState` (optional, default: ``None``)
If provided, this will specify an exact point of execution from which
the algorithm will continue. The expected object is a
:class:`~graph_tool.community.MinimizeState`
instance which will be passed to the callback of the ``checkpoint``
option above, and can be stored by :mod:`pickle`.
random_bisection : ``bool`` (optional, default: ``False``)
If ``True``, the best value of ``B`` will be found by performing a
random bisection, instead of a Fibonacci search.
......@@ -2209,10 +2133,9 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
dl_ent = kwargs.get("dl_ent", False)
ignore_degrees = kwargs.get("ignore_degrees", None)
if minimize_state is None:
minimize_state = MinimizeState()
b_cache = {}
if overlap and nonoverlap_init and minimize_state.init:
if overlap and nonoverlap_init:
if verbose:
print("Non-overlapping initialization...")
state = minimize_blockmodel_dl(g=g, ec=ec,
......@@ -2229,8 +2152,6 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
r=r, nmerge_sweeps=nmerge_sweeps,
max_B=max_B, min_B=min_B, mid_B=mid_B,
clabel=clabel if isinstance(clabel, PropertyMap) else None,
checkpoint=checkpoint,
minimize_state=minimize_state,
exhaustive=exhaustive, max_BE=max_BE,
nested_dl=nested_dl, overlap=False,
init_states=None, dl_ent=dl_ent,
......@@ -2241,9 +2162,6 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
max_B = state.B
init_states = [state]
minimize_state.clear()
minimize_state.init = False
if min_B is None:
min_B = state.clabel.fa.max() + 1
......@@ -2276,25 +2194,23 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
greedy = greedy_cooling
shrink = True
b_cache = minimize_state.b_cache
checkpoint_state = minimize_state.checkpoint_state
kwargs = dict(nsweeps=nsweeps, adaptive_sweeps=adaptive_sweeps, c=c,
sequential=sequential, parallel=parallel, shrink=shrink, r=r,
anneal=anneal, greedy=greedy, epsilon=epsilon,
nmerge_sweeps=nmerge_sweeps, deg_corr=deg_corr, dense=dense,
multigraph=multigraph, dl=dl,
sparse_heuristic=sparse_heuristic, checkpoint=checkpoint,
minimize_state=minimize_state, nested_dl=nested_dl,
sparse_heuristic=sparse_heuristic, nested_dl=nested_dl,
nested_overlap=nested_overlap,
nonoverlap_compare=nonoverlap_compare, dl_ent=dl_ent,
confine_layers=confine_layers, verbose=verbose)
confine_layers=confine_layers, b_cache=b_cache,
verbose=verbose)
if init_states is not None:
for state in init_states:
if _bm_test():
assert state._BlockState__check_clabel(), "init state has invalid clabel!"
dl = get_b_dl(state, kwargs.get("dense", False),
dl = get_b_dl(state,
kwargs.get("dense", False),
kwargs.get("multigraph", False),
kwargs.get("nested_dl", False),
kwargs.get("complete", False),
......@@ -2310,13 +2226,15 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
if B_init:
if ec is None:
if overlap:
state = OverlapBlockState(g, B=2 * g.num_edges(), deg_corr=deg_corr,
vweight=vweight, eweight=eweight,
clabel=clabel, max_BE=max_BE)
state = OverlapBlockState(g, B=2 * g.num_edges(),
deg_corr=deg_corr, vweight=vweight,
eweight=eweight, clabel=clabel,
max_BE=max_BE)
else:
state = BlockState(g, B=g.num_vertices(), deg_corr=deg_corr,
vweight=vweight, eweight=eweight, clabel=clabel,
max_BE=max_BE, ignore_degrees=ignore_degrees)
vweight=vweight, eweight=eweight,
clabel=clabel, max_BE=max_BE,
ignore_degrees=ignore_degrees)
else:
if overlap:
......@@ -2342,7 +2260,8 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
if _bm_test():
assert state._BlockState__check_clabel(), "clabel invalid at creation!"
dl = get_b_dl(state, kwargs.get("dense", False),
dl = get_b_dl(state,
kwargs.get("dense", False),
kwargs.get("multigraph", False),
kwargs.get("nested_dl", False),
kwargs.get("complete", False),
......@@ -2359,8 +2278,7 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
for B in reversed(range(min_B, max_B + 1)):
if B in b_cache:
state = b_cache[B][1]
if checkpoint_state[B].get("done", False):
continue
continue
args = kwargs.copy()
if sparse_heuristic:
......@@ -2420,9 +2338,6 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
if verbose:
print("Current bracket:", (min_B, mid_B, max_B), (f_min, f_mid, f_max))
if checkpoint is not None:
checkpoint(None, 0, 0, 0, minimize_state)
cleanup_cache(b_cache, min_B, max_B)
if f_mid > f_min or f_mid > f_max:
......@@ -2468,9 +2383,6 @@ def minimize_blockmodel_dl(g, deg_corr=True, overlap=False, ec=None,
return b_cache[best_B][1]
if checkpoint is not None:
checkpoint(None, 0, 0, 0, minimize_state)
if f_x < f_mid:
if max_B - mid_B > mid_B - min_B:
min_B = mid_B
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment